gaoqiong / lm-evaluation-harness · Commits

Commit ce164bb1
authored Nov 23, 2023 by baberabb

fix imports

Parent: 667fc837
Showing 1 changed file with 26 additions and 19 deletions.

lm_eval/models/vllm_causallms.py (+26, −19):
@@ -8,11 +8,7 @@ from tqdm import tqdm
 from lm_eval.api.registry import register_model
 from lm_eval import utils
 
-# TODO: Fix this once complete
-# flake8: noqa
-try:
-    from vllm import LLM, SamplingParams
-except ModuleNotFoundError:
-    pass
 
 @register_model("vllm")
@@ -34,6 +30,7 @@ class VLLM(LM):
         max_length: int = None,
     ):
         super().__init__()
+        from vllm import LLM, SamplingParams
 
         self.model = LLM(
             model=pretrained,
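Note on the hunk above: moving `from vllm import LLM, SamplingParams` from module scope into `__init__` is the usual deferred-import pattern for optional dependencies. The module stays importable (so the `@register_model("vllm")` registration can still be discovered) even when vllm is not installed, and the failure surfaces only when someone actually constructs the model. A minimal sketch of the pattern, with a hypothetical `heavy_backend` package standing in for vllm:

from typing import Any

class OptionalBackendModel:
    # Hypothetical wrapper; `heavy_backend` stands in for vllm.
    def __init__(self, pretrained: str) -> None:
        # Deferred import: the optional dependency is only needed at
        # construction time, so importing this module never fails.
        try:
            from heavy_backend import Engine  # hypothetical package
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "install the optional backend to use this model"
            ) from err
        self.engine: Any = Engine(model=pretrained)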
@@ -68,9 +65,17 @@ class VLLM(LM):
     def max_gen_toks(self):
         return self._max_gen_toks
 
-    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=False):
+    def tok_encode(
+        self,
+        string: str,
+        left_truncate_len=None,
+        add_special_tokens=False,
+        truncation=False,
+    ):
         """ """
-        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
+        encoding = self.tokenizer.encode(
+            string, add_special_tokens=add_special_tokens, truncation=truncation
+        )
 
         # left-truncate the encoded context to be at most `left_truncate_len` tokens long
         if left_truncate_len:
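The new `truncation` flag is forwarded verbatim to the Hugging Face tokenizer's `encode`, which accepts `truncation` (optionally bounded by `max_length`). A quick usage sketch; the checkpoint name is arbitrary:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF checkpoint

# Mirrors tok_encode(string, add_special_tokens=False, truncation=True):
ids = tokenizer.encode(
    "a very long prompt " * 1000,
    add_special_tokens=False,
    truncation=True,  # truncates to the tokenizer's model_max_length
)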
@@ -109,14 +114,14 @@ class VLLM(LM):
         )
         return outputs
 
-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
                 # end of text as context
                 context_enc, continuation_enc = [
-                    self.eot_token_id], self.tok_encode(
-                    continuation
-                )
+                    self.eot_token_id
+                ], self.tokenizer.tok_encode(continuation)
             else:
                 context_enc, continuation_enc = self.tokenizer(
                     [context, continuation],
@@ -129,7 +134,7 @@ class VLLM(LM):
         return self._loglikelihood_tokens(new_reqs)
 
-    def loglikelihood_rolling(self, requests) -> List[float]:
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
 
         for (string,) in tqdm([req.args for req in requests]):
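Both signatures above now take `requests: List[Instance]`; that annotation assumes `Instance` (the request dataclass from `lm_eval.api.instance`) is imported at the top of the module, which is presumably part of what the commit title refers to. Each `Instance.args` is a plain tuple, which is why the loops unpack `req.args` directly. A stand-in illustration (not the real class):

from dataclasses import dataclass
from typing import Tuple

@dataclass
class FakeInstance:
    # Stand-in for lm_eval.api.instance.Instance; only .args matters here.
    args: Tuple[str, ...]

reqs = [FakeInstance(args=("The capital of France is", " Paris"))]
for context, continuation in [req.args for req in reqs]:
    print(repr(context), repr(continuation))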
@@ -256,7 +261,9 @@ class VLLM(LM):
         return grouper.get_original(res)
 
     def _loglikelihood_tokens(
-        self, requests, disable_tqdm: bool = False
+        self,
+        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        disable_tqdm: bool = False,
     ) -> List[Tuple[float, bool]]:
         res = []
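The widened annotation spells out what `loglikelihood` feeds into `_loglikelihood_tokens`: each entry pairs the raw `(context, continuation)` strings with the two token-id lists. An illustrative value (the ids are made up):

from typing import List, Tuple

LoglikelihoodRequest = Tuple[Tuple[str, str], List[int], List[int]]

req: LoglikelihoodRequest = (
    ("The capital of France is", " Paris"),  # raw strings
    [464, 3139, 286, 4881, 318],  # context token ids (made up)
    [6342],  # continuation token ids (made up)
)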
@@ -271,7 +278,7 @@ class VLLM(LM):
             n=self.batch_size,
             fn=None,
         )
-        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
+        pbar = tqdm(total=len(requests), disable=disable_tqdm)
         for chunk in chunks:
             inps = []
             ctxlens = []
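One behavioural note: the old bar passed `disable=(disable_tqdm or (self.rank != 0))`, silencing progress output on non-zero ranks, while the new call keys off `disable_tqdm` alone. If rank-aware silencing were still wanted, the generic pattern looks like this (a sketch, not this repo's code):

from tqdm import tqdm

def iterate(items, disable_tqdm: bool = False, rank: int = 0):
    # Show the bar only on the main process so multi-worker runs
    # do not print one bar per rank.
    pbar = tqdm(total=len(items), disable=(disable_tqdm or rank != 0))
    for item in items:
        pbar.update(1)
    pbar.close()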
@@ -305,11 +312,11 @@ class VLLM(LM):
         return re_ord.get_original(res)
 
     @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int):
+    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
         """Process logprobs and tokens.
 
         :param tokens: list
-            Tokens from response
+            Tokens from context+continuations
         :param outputs: RequestOutput
             Contains prompt
         :param ctxlen: int
@@ -321,13 +328,13 @@ class VLLM(LM):
             Whether argmax matches given continuation exactly
         """
 
-        # Extract the logprobs for the continuation tokens
+        # prompt_logprobs = [None, {}*len(context-1)]
         continuation_logprobs_dicts = outputs.prompt_logprobs
 
         # Calculate continuation_logprobs
+        # assume ctxlen always > 1
         continuation_logprobs = sum(
+            # Use .get to avoid KeyError and default to 0
             logprob_dict.get(token)
             for token, logprob_dict in zip(
                 tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
             )
...
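For context on what `_parse_logprobs` computes: vLLM's `RequestOutput.prompt_logprobs` is a list aligned with the prompt tokens whose first entry is None (the first token has nothing to condition on), and in this vintage of vLLM the remaining entries are dicts mapping token ids to logprobs. Summing the entries from `ctxlen` onward yields the continuation log-likelihood, and the continuation is greedy iff each of its tokens is the argmax of its slot. A self-contained sketch of that arithmetic, with made-up numbers:

from typing import Dict, List, Optional, Tuple

def parse_logprobs_sketch(
    tokens: List[int],
    prompt_logprobs: List[Optional[Dict[int, float]]],
    ctxlen: int,
) -> Tuple[float, bool]:
    # Sum the logprob assigned to each actual continuation token.
    total = sum(
        d[tok] for tok, d in zip(tokens[ctxlen:], prompt_logprobs[ctxlen:])
    )
    # Greedy iff every continuation token had the highest logprob in its slot.
    is_greedy = all(
        tok == max(d, key=d.get)
        for tok, d in zip(tokens[ctxlen:], prompt_logprobs[ctxlen:])
    )
    return total, is_greedy

# Two context tokens, two continuation tokens:
toks = [11, 22, 33, 44]
plps = [None, {22: -0.5, 9: -2.0}, {33: -0.1, 7: -3.0}, {44: -0.7, 5: -0.2}]
print(parse_logprobs_sketch(toks, plps, ctxlen=2))  # (-0.8, False); 44 is not the argmax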