gaoqiong / lm-evaluation-harness · Commits

Commit 223b9488
authored Jul 23, 2025 by Baber

types

parent 7cef4d38

Showing 4 changed files with 132 additions and 135 deletions (+132 −135)
lm_eval/models/api_models.py          +67 −72
lm_eval/models/huggingface.py          +1 −2
lm_eval/models/openai_completions.py  +26 −25
lm_eval/models/vllm_causallms.py      +38 −36
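The whole commit is a mechanical modernization of type hints: `typing.Optional`/`Union`/`List`/`Dict`/`Tuple` are replaced by PEP 604 unions (`X | None`, `A | B`) and PEP 585 builtin generics (`list[int]`, `dict[str, str]`), with `from __future__ import annotations` added to each touched module so the new syntax stays an unevaluated string at runtime and keeps working on older interpreters. A minimal before/after sketch of the convention (function names are illustrative, not taken from the harness):

    from __future__ import annotations  # annotations are no longer evaluated at runtime

    from typing import List, Optional, Union  # old-style imports this commit removes


    # before: typing wrappers
    def encode_old(text: str, max_len: Optional[int] = None, batch: Union[str, int] = 1) -> List[int]:
        return []


    # after: builtin generics and `|` unions, as applied throughout this diff
    def encode_new(text: str, max_len: int | None = None, batch: str | int = 1) -> list[int]:
        return []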
lm_eval/models/api_models.py

+from __future__ import annotations
+
 import abc
 import asyncio
 import copy
@@ -8,16 +10,9 @@ from functools import cached_property
 from typing import (
     TYPE_CHECKING,
     Any,
-    Awaitable,
     Callable,
-    Dict,
-    Iterable,
-    List,
     Literal,
     NamedTuple,
-    Optional,
-    Tuple,
-    Union,
 )
@@ -36,18 +31,21 @@ from importlib.util import find_spec
 from io import BytesIO

 from lm_eval import utils
-from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.models.utils import Collator, chunks, configure_pad_token


 if TYPE_CHECKING:
+    from collections.abc import Awaitable, Iterable
+
     from PIL import Image

+    from lm_eval.api.instance import Instance
+

 eval_logger = logging.getLogger(__name__)

-LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
+LogLikelihoodInputs = tuple[tuple[str, str], list[int], list[int]]


 # utility class to keep track of json encoded chats
@@ -58,9 +56,7 @@ class JsonChatStr(NamedTuple):
         return self.prompt.encode(encoding)


 def create_image_prompt(
-    imgs: list[Image.Image], chat: dict, fmt: str = "PNG"
+    imgs: list["Image.Image"], chat: dict, fmt: str = "PNG"
 ) -> dict:
     """
     Parameters
@@ -109,33 +105,32 @@ class TemplateAPI(TemplateLM):
         model: str = None,
         pretrained: str = None,  # `model` takes precedence over `pretrained` when passed.
         base_url: str = None,
-        tokenizer: Optional[str] = None,
+        tokenizer: str | None = None,
         # Loglikelihood tasks require a tokenizer to calculate context lengths,
         # however the requests can be sent as a string if the API doesn't support token inputs.
         # use tokenized_requests=False
-        tokenizer_backend: Optional[
-            Literal["tiktoken", "huggingface", "None", "none"]
-        ] = "huggingface",
+        tokenizer_backend: Literal["tiktoken", "huggingface", "None", "none"]
+        | None = "huggingface",
         truncate: bool = False,
         # number of concurrent requests. More useful if not batching
         num_concurrent: int = 1,
         max_retries: int = 3,
         max_gen_toks: int = 256,
-        batch_size: Union[str, int] = 1,
+        batch_size: str | int = 1,
         seed: int = 1234,
-        max_length: Optional[int] = 2048,
+        max_length: int | None = 2048,
         add_bos_token: bool = False,
         custom_prefix_token_id: int = None,
         # send the requests as tokens or strings
         tokenized_requests: bool = True,
         trust_remote_code: bool = False,
-        revision: Optional[str] = "main",
+        revision: str | None = "main",
         use_fast_tokenizer: bool = True,
         verify_certificate: bool = True,
         eos_string: str = None,
         # timeout in seconds
         timeout: int = 300,
-        header: Optional[Dict[str, str]] = None,
+        header: dict[str, str] | None = None,
         max_images: int = 1,
         **kwargs,
     ) -> None:
@@ -232,12 +227,12 @@ class TemplateAPI(TemplateLM):
     @abc.abstractmethod
     def _create_payload(
         self,
-        messages: Union[List[List[int]], List[dict], List[str], str],
+        messages: list[list[int]] | list[dict] | list[str] | str,
         *,
         generate: bool = True,
-        gen_kwargs: Optional[dict] = None,
+        gen_kwargs: dict | None = None,
         seed: int = 1234,
-        eos: str = None,
+        eos: str | None = None,
         **kwargs,
     ) -> dict:
         """This method is responsible for creating the json payload that will be sent to the API."""
@@ -245,9 +240,9 @@ class TemplateAPI(TemplateLM):
     def create_message(
         self,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         generate=False,
-    ) -> Union[List[List[int]], List[dict], List[str], str]:
+    ) -> list[list[int]] | list[dict] | list[str] | str:
         """Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
         if isinstance(messages[0], JsonChatStr):
             # for chat completions we need to decode the json string to list[dict,...]
@@ -276,17 +271,17 @@ class TemplateAPI(TemplateLM):
     @staticmethod
     @abc.abstractmethod
     def parse_logprobs(
-        outputs: Union[Any, List[Any]],
-        tokens: List[List[int]] = None,
-        ctxlen: List[int] = None,
+        outputs: Any | list[Any],
+        tokens: list[list[int]] | None = None,
+        ctxlen: list[int] | None = None,
         **kwargs,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         """Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
         raise NotImplementedError

     @staticmethod
     @abc.abstractmethod
-    def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
+    def parse_generations(outputs: Any | list[Any], **kwargs) -> list[str]:
         """Method used to parse the generations from the (batched) API response. This method should return a list of str"""
         raise NotImplementedError
@@ -303,14 +298,15 @@ class TemplateAPI(TemplateLM):
     @property
     def tokenizer_name(self) -> str:
         """Must be defined for LM subclasses which implement Chat Templating.
         Should return the name of the tokenizer or chat template used.
         Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
         """
         return ""

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
-    ) -> Union[str, JsonChatStr]:
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
+    ) -> str | JsonChatStr:
         """Applies a chat template to a list of chat history between user and model."""
         if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
             return self.tokenizer.apply_chat_template(
@@ -319,33 +315,32 @@ class TemplateAPI(TemplateLM):
                 add_generation_prompt=add_generation_prompt,
                 continue_final_message=not add_generation_prompt,
             )
         else:
             # bit of a hack. We'll load back before sending to the API
             return JsonChatStr(
                 json.dumps(
                     [{**item, "type": "text"} for item in chat_history],
                     ensure_ascii=False,
                 )
             )

     @cached_property
-    def eot_token_id(self) -> Optional[int]:
+    def eot_token_id(self) -> int | None:
         if self.tokenizer is None:
             return None
         else:
             if self.tokenizer_backend == "huggingface":
                 return self.tokenizer.eos_token_id
             elif self.tokenizer_backend == "tiktoken":
                 return self.tokenizer.eot_token

     @cached_property
-    def eos_string(self) -> Optional[str]:
+    def eos_string(self) -> str | None:
         if self._eos_string:
             return self._eos_string
         elif self.tokenizer is not None:
             if self.tokenizer_backend == "huggingface":
                 return self.tokenizer.eos_token
             elif self.tokenizer_backend == "tiktoken":
                 return self.tokenizer.decode([self.tokenizer.eot_token])
         else:
             eval_logger.warning(
@@ -354,7 +349,7 @@ class TemplateAPI(TemplateLM):
         return None

     @cached_property
-    def prefix_token_id(self) -> Optional[int]:
+    def prefix_token_id(self) -> int | None:
         if self.tokenizer is None:
             return None
         else:
@@ -364,24 +359,24 @@ class TemplateAPI(TemplateLM):
             if self.tokenizer.bos_token_id is not None:
                 return self.tokenizer.bos_token_id
             return self.tokenizer.eos_token_id
         else:
             return self.tokenizer.eot_token

     def tok_encode(
         self,
         string: str,
-        left_truncate_len: int = None,
+        left_truncate_len: int | None = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
         **kwargs,
-    ) -> Union[List[List[int]], List[int], List[str]]:
+    ) -> list[list[int]] | list[int] | list[str]:
         if self.tokenizer_backend is None:
             return [string]
         elif self.tokenizer_backend == "huggingface":
             # by default for CausalLM - false or self.add_bos_token is set
             if not add_special_tokens:
                 add_special_tokens = False or self.add_bos_token
-            encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+            encoding: list[list[int]] | list[int] = self.tokenizer(
                 string,
                 add_special_tokens=add_special_tokens,
                 truncation=truncation,
@@ -404,20 +399,20 @@ class TemplateAPI(TemplateLM):
             encoding = self.tokenizer.encode_batch(string)
             return encoding

-    def decode_batch(self, tokens: List[List[int]]) -> List[str]:
+    def decode_batch(self, tokens: list[list[int]]) -> list[str] | None:
         if self.tokenizer_backend == "huggingface":
             return self.tokenizer.batch_decode(tokens)
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.decode_batch(tokens)

     def model_call(
         self,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         *,
         generate: bool = True,
-        gen_kwargs: Optional[Dict] = None,
+        gen_kwargs: dict | None = None,
         **kwargs,
-    ) -> Optional[dict]:
+    ) -> dict | None:
         # !!! Copy: shared dict for each request, need new object !!!
         gen_kwargs = copy.deepcopy(gen_kwargs)
         try:
@@ -441,7 +436,7 @@ class TemplateAPI(TemplateLM):
             response.raise_for_status()
             return response.json()
         except RetryError:
-            eval_logger.error(
+            eval_logger.exception(
                 "API request failed after multiple retries. Please check the API status."
             )
             return None
@@ -450,14 +445,14 @@ class TemplateAPI(TemplateLM):
         self,
         session: ClientSession,
         sem: asyncio.Semaphore,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         *,
         generate: bool = True,
-        cache_keys: list = None,
-        ctxlens: Optional[List[int]] = None,
-        gen_kwargs: Optional[Dict] = None,
+        cache_keys: list | None = None,
+        ctxlens: list[int] | None = None,
+        gen_kwargs: dict | None = None,
         **kwargs,
-    ) -> Union[List[str], List[Tuple[float, bool]], None]:
+    ) -> list[str] | list[tuple[float, bool]] | None:
         # !!! Copy: shared dict for each request, need new object !!!
         gen_kwargs = copy.deepcopy(gen_kwargs)
         payload = self._create_payload(
@@ -508,8 +503,8 @@ class TemplateAPI(TemplateLM):
             sem.release()

     def batch_loglikelihood_requests(
-        self, chunks: Iterable[List[LogLikelihoodInputs]]
-    ) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
+        self, chunks: Iterable[list[LogLikelihoodInputs]]
+    ) -> tuple[list[list[int]], list[int], list[tuple[str, str]]]:
         inputs = []
         ctxlens = []
         cache_keys = []
@@ -536,9 +531,9 @@ class TemplateAPI(TemplateLM):
         cache_keys: list,
         *,
         generate: bool = True,
-        ctxlens: List[int] = None,
+        ctxlens: list[int] | None = None,
         **kwargs,
-    ) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
+    ) -> list[list[str]] | list[list[tuple[float, bool]]]:
         ctxlens = ctxlens if ctxlens else [None] * len(requests)
         conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
         sem = asyncio.Semaphore(self._concurrent)
@@ -575,14 +570,14 @@ class TemplateAPI(TemplateLM):
             return await tqdm_asyncio.gather(*tasks, desc="Requesting API")

-    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
+    def _loglikelihood_tokens(self, requests, **kwargs) -> list[tuple[float, bool]]:
         assert self.tokenizer is not None, (
             "Tokenizer is required for loglikelihood tasks to compute context lengths."
         )
         res = []

         def _collate(req: LogLikelihoodInputs):
-            """Defines the key for the sorted method"""
+            """Defines the key for the sorted method."""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
             # - to know the size of a batch when going through the list, you know the first one is always the batch
@@ -639,8 +634,8 @@ class TemplateAPI(TemplateLM):
         return re_ord.get_original(res)

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []

         def _collate_gen(_requests):
@@ -773,8 +768,8 @@ class TemplateAPI(TemplateLM):
         return re_ord.get_original(res)

     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         loglikelihoods = []

         for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
...
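Aside from the annotation rewrites, one small behavioral change is buried in this file: the `RetryError` handler switches from `eval_logger.error(...)` to `eval_logger.exception(...)`, which logs the same message at ERROR level but also attaches the traceback of the exception currently being handled. A standalone sketch of the difference (generic logger, not the harness's):

    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger("api_demo")

    try:
        raise TimeoutError("simulated API failure after retries")
    except TimeoutError:
        log.error("API request failed")      # message only
        log.exception("API request failed")  # message plus the active traceback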
lm_eval/models/huggingface.py

@@ -682,8 +682,7 @@ class HFLM(TemplateLM):
             )

         if peft:
-            from peft import PeftModel
-            from peft import __version__ as PEFT_VERSION
+            from peft import PeftModel, __version__ as PEFT_VERSION

             if model_kwargs.get("load_in_4bit") and vparse(PEFT_VERSION) < vparse(
                 "0.4.0"
...
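The merged import only matters because `PEFT_VERSION` feeds the version gate visible in the context lines above. A hedged sketch of that kind of check, assuming `vparse` is `packaging.version.parse` (as the surrounding code suggests) and with an illustrative error rather than the harness's actual handling:

    from packaging.version import parse as vparse
    from peft import __version__ as PEFT_VERSION

    load_in_4bit = True  # stand-in for model_kwargs.get("load_in_4bit")
    if load_in_4bit and vparse(PEFT_VERSION) < vparse("0.4.0"):
        # illustrative handling only; the harness's real message/behavior may differ
        raise AssertionError("load_in_4bit with a PEFT adapter requires peft >= 0.4.0")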
lm_eval/models/openai_completions.py

+from __future__ import annotations
+
 import logging
 import os
 from functools import cached_property
 from operator import itemgetter
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any

 from lm_eval.api.registry import register_model
 from lm_eval.models.api_models import TemplateAPI
@@ -26,9 +28,9 @@ class LocalCompletionsAPI(TemplateAPI):
     def _create_payload(
         self,
-        messages: Union[List[List[int]], List[dict], List[str], str],
+        messages: list[list[int]] | list[dict] | list[str] | str,
         generate=False,
-        gen_kwargs: Optional[dict] = None,
+        gen_kwargs: dict | None = None,
         seed: int = 1234,
         eos=None,
         **kwargs,
@@ -50,24 +52,23 @@ class LocalCompletionsAPI(TemplateAPI):
                 "seed": seed,
                 **gen_kwargs,
             }
-        else:
         return {
             "model": self.model,
             "prompt": messages,
             "temperature": 0,
             "max_tokens": 1,
             "logprobs": 1,
             "seed": seed,
             "echo": True,
         }

     @staticmethod
     def parse_logprobs(
-        outputs: Union[Dict, List[Dict]],
-        tokens: List[List[int]] = None,
-        ctxlens: List[int] = None,
+        outputs: dict | list[dict],
+        tokens: list[list[int]] = None,
+        ctxlens: list[int] = None,
         **kwargs,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
@@ -88,7 +89,7 @@ class LocalCompletionsAPI(TemplateAPI):
         return res

     @staticmethod
-    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
+    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
@@ -130,9 +131,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
     def _create_payload(
         self,
-        messages: List[Dict],
+        messages: list[dict],
         generate=False,
-        gen_kwargs: dict = None,
+        gen_kwargs: dict | None = None,
         seed=1234,
         eos=None,
         **kwargs,
@@ -160,7 +161,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
         }

     @staticmethod
-    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
+    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
@@ -173,11 +174,11 @@ class LocalChatCompletion(LocalCompletionsAPI):
     def tok_encode(
         self,
-        string: Union[str, Any],
+        string: str | Any,
         left_truncate_len=None,
         add_special_tokens=None,
         **kwargs,
-    ) -> Union[List[str], List[int], Any]:
+    ) -> list[str] | list[int] | Any:
         return string

     def loglikelihood(self, requests, **kwargs):
@@ -219,7 +220,7 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
         )
         return super().loglikelihood(requests, **kwargs)

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         return ""
@@ -261,7 +262,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
     def _create_payload(
         self,
-        messages: List[Dict],
+        messages: list[dict],
         generate=False,
         gen_kwargs: dict = None,
         seed=1234,
...
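The non-generate branch of `_create_payload` (left as context above) is the usual completions-API trick for scoring: send the full prompt with `max_tokens: 1`, `logprobs: 1`, and `echo: True`, so the response echoes the prompt back with a logprob attached to each prompt token, which `parse_logprobs` can then slice by context length. A minimal sketch of such a request body (model name and prompt are placeholders):

    payload = {
        "model": "my-completions-model",           # placeholder
        "prompt": "The capital of France is Paris",
        "temperature": 0,
        "max_tokens": 1,   # we only want the echoed prompt, not a continuation
        "logprobs": 1,     # ask for per-token logprobs
        "seed": 1234,
        "echo": True,      # include the prompt tokens (and their logprobs) in the output
    }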
lm_eval/models/vllm_causallms.py

+from __future__ import annotations
+
 import copy
 import gc
 import logging
@@ -7,7 +9,7 @@ from importlib.util import find_spec
 from multiprocessing import Process, Queue
 from queue import Empty
 from time import sleep
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal

 import jinja2
 from more_itertools import distribute
@@ -113,30 +115,30 @@ class VLLM(TemplateLM):
         self,
         pretrained: str,
         dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
-        revision: Optional[str] = None,
-        trust_remote_code: Optional[bool] = False,
-        tokenizer: Optional[str] = None,
+        revision: str | None = None,
+        trust_remote_code: bool | None = False,
+        tokenizer: str | None = None,
         tokenizer_mode: Literal["auto", "slow"] = "auto",
-        tokenizer_revision: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
-        prefix_token_id: Optional[int] = None,
+        tokenizer_revision: str | None = None,
+        add_bos_token: bool | None = False,
+        prefix_token_id: int | None = None,
         tensor_parallel_size: int = 1,
-        quantization: Optional[str] = None,
+        quantization: str | None = None,
         max_gen_toks: int = 256,
         swap_space: int = 4,
-        batch_size: Union[str, int] = 1,
-        max_batch_size=None,
-        max_length: int = None,
-        max_model_len: int = None,
+        batch_size: str | int = 1,
+        max_batch_size: int | None = None,
+        max_length: int | None = None,
+        max_model_len: int | None = None,
         seed: int = 1234,
         gpu_memory_utilization: float = 0.9,
         data_parallel_size: int = 1,
-        lora_local_path: str = None,
+        lora_local_path: str | None = None,
         # VLLM: enable thinking tags in the prompt.
         enable_thinking: bool = True,
         chat_template_args: Optional[dict] = None,
         # End marker for thinking tags - splits to get response after this token (if provided).
-        think_end_token: Optional[str] = None,
+        think_end_token: str | None = None,
         max_lora_rank: int = 16,
         **kwargs,
     ):
@@ -172,7 +174,7 @@ class VLLM(TemplateLM):
             "swap_space": int(swap_space),
             "quantization": quantization,
             "seed": int(seed),
-            "enable_lora": True if lora_local_path else False,
+            "enable_lora": bool(lora_local_path),
             "max_lora_rank": int(max_lora_rank),
         }
         self.model_args.update(kwargs)
@@ -300,7 +302,7 @@ class VLLM(TemplateLM):
         return self._max_gen_toks

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
@@ -337,14 +339,14 @@ class VLLM(TemplateLM):
     def tok_encode(
         self,
-        string: Union[str, List[str]],
+        string: str | list[str],
         left_truncate_len: int = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
-    ) -> Union[List[int], List[List[int]]]:
+    ) -> list[int] | list[list[int]]:
         if not add_special_tokens:
             add_special_tokens = False or self.add_bos_token
-        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+        encoding: list[list[int]] | list[int] = self.tokenizer(
             string,
             add_special_tokens=add_special_tokens,
             truncation=truncation,
@@ -362,7 +364,7 @@ class VLLM(TemplateLM):
     def _model_generate(
         self,
-        requests: List[List[int]] = None,
+        requests: list[list[int]] = None,
         generate: bool = False,
         sampling_params: Union[List["SamplingParams"], "SamplingParams", None] = None,
     ):
@@ -379,8 +381,8 @@ class VLLM(TemplateLM):
             @ray.remote
             def run_inference_one_model(
                 model_args: dict,
-                sampling_params: List["SamplingParams"],
-                requests: List[List[int]],
+                sampling_params: list["SamplingParams"],
+                requests: list[list[int]],
                 lora_request: "LoRARequest",
             ):
                 llm = LLM(**model_args)
@@ -454,7 +456,7 @@ class VLLM(TemplateLM):
                 if dead_procs:
                     raise RuntimeError(
                         f"Worker processes {dead_procs} died unexpectedly"
-                    )
+                    ) from None
                 continue
             results = [rank_res[i] for i in range(len(procs))]
@@ -481,14 +483,14 @@ class VLLM(TemplateLM):
             outputs = self.model.generate(
                 [TokensPrompt(prompt_token_ids=request) for request in requests],
                 sampling_params=sampling_params,
-                use_tqdm=True if self.batch_size == "auto" else False,
+                use_tqdm=self.batch_size == "auto",
                 lora_request=self.lora_request,
             )
             return outputs

     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             adaptive_batch_size = len(requests)
@@ -503,7 +505,7 @@ class VLLM(TemplateLM):
                 disable=(disable_tqdm or (self.rank != 0)),
             )
         ):
-            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                 map(
                     make_disjoint_window,
                     get_rolling_token_windows(
@@ -556,13 +558,13 @@ class VLLM(TemplateLM):
         return loglikelihoods

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []

         # batch tokenize contexts
         context, all_gen_kwargs = zip(*(req.args for req in requests))
-        context_encoding: List[List[int]] = self.tok_encode(
+        context_encoding: list[list[int]] = self.tok_encode(
             context, add_special_tokens=self.add_bos_token
         )
         requests = [
@@ -638,7 +640,7 @@ class VLLM(TemplateLM):
             )

             # cache generations
-            for output, context in zip(cont, context):
+            for output, context_ in zip(cont, context):
                 generated_text: str = output.outputs[0].text
                 # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                 generated_text = postprocess_generated_text(
@@ -646,7 +648,7 @@ class VLLM(TemplateLM):
                 )
                 res.append(generated_text)
                 self.cache_hook.add_partial(
-                    "generate_until", (context, gen_kwargs), generated_text
+                    "generate_until", (context_, gen_kwargs), generated_text
                 )
                 pbar.update(1)
@@ -656,9 +658,9 @@ class VLLM(TemplateLM):
     def _loglikelihood_tokens(
         self,
-        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        requests: list[tuple[tuple[str, str], list[int], list[int]]],
         disable_tqdm: bool = False,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []

         def _collate(x):
@@ -679,7 +681,7 @@ class VLLM(TemplateLM):
         for chunk in chunks:
             inputs = []
             ctxlens = []
-            for cache_key, context_enc, continuation_enc in chunk:
+            for _cache_key, context_enc, continuation_enc in chunk:
                 if (
                     full_length := len(context_enc + continuation_enc)
                 ) > self.max_length:
@@ -717,7 +719,7 @@ class VLLM(TemplateLM):
         return re_ord.get_original(res)

     @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
+    def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
         """Process logprobs and tokens.

         :param tokens: list
...
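Two of the non-annotation tweaks here are worth a note: `enable_lora` and `use_tqdm` drop the `True if ... else False` pattern in favor of the boolean expression itself, and the worker-failure `RuntimeError` is now raised `from None`, which suppresses the implicit chaining onto the `queue.Empty` being handled. A small sketch of the `from None` effect, using a stand-in queue:

    from queue import Empty, Queue

    q: Queue = Queue()
    try:
        q.get(timeout=0.01)  # always times out here
    except Empty:
        # without `from None`, the traceback would also print the Empty exception
        # under "During handling of the above exception, another exception occurred"
        raise RuntimeError("Worker processes died unexpectedly") from None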