gaoqiong / lm-evaluation-harness · Commits

Commit 223b9488
authored Jul 23, 2025 by Baber
parent 7cef4d38

Commit message: types

Showing 4 changed files with 132 additions and 135 deletions:

  lm_eval/models/api_models.py          +67 -72
  lm_eval/models/huggingface.py          +1  -2
  lm_eval/models/openai_completions.py  +26 -25
  lm_eval/models/vllm_causallms.py      +38 -36
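The commit message is simply "types": every hunk below rewrites annotations from the typing-module style (Optional, Union, List, Dict, Tuple) to the PEP 604 / PEP 585 style (X | None, list[int], dict[str, str]), and adds from __future__ import annotations to each touched module so that annotations stay unevaluated at runtime (PEP 563) and the new syntax inside them does not require Python 3.10. A minimal before/after sketch of the pattern; the functions here are illustrative and not taken from the diff:

from __future__ import annotations  # PEP 563: annotations are not evaluated at runtime

from typing import List, Optional, Union  # the kind of imports this commit trims away


# Before: typing-module generics and wrappers.
def score_old(tokens: Optional[List[int]] = None, label: Union[str, int] = 0) -> List[str]:
    return [f"{label}:{tok}" for tok in (tokens or [])]


# After: builtin generics (PEP 585) and unions written with | (PEP 604).
def score_new(tokens: list[int] | None = None, label: str | int = 0) -> list[str]:
    return [f"{label}:{tok}" for tok in (tokens or [])]


assert score_old([1, 2], "x") == score_new([1, 2], "x") == ["x:1", "x:2"]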
lm_eval/models/api_models.py  View file @ 223b9488

+from __future__ import annotations
+
 import abc
 import asyncio
 import copy
...
@@ -8,16 +10,9 @@ from functools import cached_property
 from typing import (
     TYPE_CHECKING,
     Any,
-    Awaitable,
     Callable,
-    Dict,
-    Iterable,
-    List,
     Literal,
     NamedTuple,
-    Optional,
-    Tuple,
-    Union,
 )
...
@@ -36,18 +31,21 @@ from importlib.util import find_spec
 from io import BytesIO

 from lm_eval import utils
-from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.models.utils import Collator, chunks, configure_pad_token

 if TYPE_CHECKING:
+    from collections.abc import Awaitable, Iterable
+
     from PIL import Image

+    from lm_eval.api.instance import Instance
+

 eval_logger = logging.getLogger(__name__)

-LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
+LogLikelihoodInputs = tuple[tuple[str, str], list[int], list[int]]


 # utility class to keep track of json encoded chats
...
@@ -58,9 +56,7 @@ class JsonChatStr(NamedTuple):
         return self.prompt.encode(encoding)


-def create_image_prompt(imgs: list["Image.Image"], chat: dict, fmt: str = "PNG") -> dict:
+def create_image_prompt(imgs: list[Image.Image], chat: dict, fmt: str = "PNG") -> dict:
     """
     Parameters
...
@@ -109,33 +105,32 @@ class TemplateAPI(TemplateLM):
         model: str = None,
         pretrained: str = None,  # `model` takes precedence over `pretrained` when passed.
         base_url: str = None,
-        tokenizer: Optional[str] = None,
+        tokenizer: str | None = None,
         # Loglikelihood tasks require a tokenizer to calculate context lengths,
         # however the requests can be sent as a string if the API doesn't support token inputs.
         # use tokenized_requests=False
-        tokenizer_backend: Optional[Literal["tiktoken", "huggingface", "None", "none"]] = "huggingface",
+        tokenizer_backend: Literal["tiktoken", "huggingface", "None", "none"] | None = "huggingface",
         truncate: bool = False,
         # number of concurrent requests. More useful if not batching
         num_concurrent: int = 1,
         max_retries: int = 3,
         max_gen_toks: int = 256,
-        batch_size: Union[str, int] = 1,
+        batch_size: str | int = 1,
         seed: int = 1234,
-        max_length: Optional[int] = 2048,
+        max_length: int | None = 2048,
         add_bos_token: bool = False,
         custom_prefix_token_id: int = None,
         # send the requests as tokens or strings
         tokenized_requests: bool = True,
         trust_remote_code: bool = False,
-        revision: Optional[str] = "main",
+        revision: str | None = "main",
         use_fast_tokenizer: bool = True,
         verify_certificate: bool = True,
         eos_string: str = None,
         # timeout in seconds
         timeout: int = 300,
-        header: Optional[Dict[str, str]] = None,
+        header: dict[str, str] | None = None,
         max_images: int = 1,
         **kwargs,
     ) -> None:
...
@@ -232,12 +227,12 @@ class TemplateAPI(TemplateLM):
     @abc.abstractmethod
     def _create_payload(
         self,
-        messages: Union[List[List[int]], List[dict], List[str], str],
+        messages: list[list[int]] | list[dict] | list[str] | str,
         *,
         generate: bool = True,
-        gen_kwargs: Optional[dict] = None,
+        gen_kwargs: dict | None = None,
         seed: int = 1234,
-        eos: str = None,
+        eos: str | None = None,
         **kwargs,
     ) -> dict:
         """This method is responsible for creating the json payload that will be sent to the API."""
...
@@ -245,9 +240,9 @@ class TemplateAPI(TemplateLM):
     def create_message(
         self,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         generate=False,
-    ) -> Union[List[List[int]], List[dict], List[str], str]:
+    ) -> list[list[int]] | list[dict] | list[str] | str:
         """Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
         if isinstance(messages[0], JsonChatStr):
             # for chat completions we need to decode the json string to list[dict,...]
...
@@ -276,17 +271,17 @@ class TemplateAPI(TemplateLM):
     @staticmethod
     @abc.abstractmethod
     def parse_logprobs(
-        outputs: Union[Any, List[Any]],
-        tokens: List[List[int]] = None,
-        ctxlen: List[int] = None,
+        outputs: Any | list[Any],
+        tokens: list[list[int]] | None = None,
+        ctxlen: list[int] | None = None,
         **kwargs,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         """Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
         raise NotImplementedError

     @staticmethod
     @abc.abstractmethod
-    def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
+    def parse_generations(outputs: Any | list[Any], **kwargs) -> list[str]:
         """Method used to parse the generations from the (batched) API response. This method should return a list of str"""
         raise NotImplementedError
...
@@ -303,14 +298,15 @@ class TemplateAPI(TemplateLM):
     @property
     def tokenizer_name(self) -> str:
         """Must be defined for LM subclasses which implement Chat Templating.
         Should return the name of the tokenizer or chat template used.
         Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
         """
         return ""

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
-    ) -> Union[str, JsonChatStr]:
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
+    ) -> str | JsonChatStr:
         """Applies a chat template to a list of chat history between user and model."""
         if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
             return self.tokenizer.apply_chat_template(
...
@@ -319,33 +315,32 @@ class TemplateAPI(TemplateLM):
                 add_generation_prompt=add_generation_prompt,
                 continue_final_message=not add_generation_prompt,
             )
-        else:
-            # bit of a hack. We'll load back before sending to the API
-            return JsonChatStr(
-                json.dumps(
-                    [{**item, "type": "text"} for item in chat_history],
-                    ensure_ascii=False,
-                )
-            )
+        # bit of a hack. We'll load back before sending to the API
+        return JsonChatStr(
+            json.dumps(
+                [{**item, "type": "text"} for item in chat_history],
+                ensure_ascii=False,
+            )
+        )

     @cached_property
-    def eot_token_id(self) -> Optional[int]:
+    def eot_token_id(self) -> int | None:
         if self.tokenizer is None:
             return None
         else:
             if self.tokenizer_backend == "huggingface":
                 return self.tokenizer.eos_token_id
-            elif self.tokenizer_backend == "tiktoken":
+            if self.tokenizer_backend == "tiktoken":
                 return self.tokenizer.eot_token

     @cached_property
-    def eos_string(self) -> Optional[str]:
+    def eos_string(self) -> str | None:
         if self._eos_string:
             return self._eos_string
-        elif self.tokenizer is not None:
+        if self.tokenizer is not None:
             if self.tokenizer_backend == "huggingface":
                 return self.tokenizer.eos_token
-            elif self.tokenizer_backend == "tiktoken":
+            if self.tokenizer_backend == "tiktoken":
                 return self.tokenizer.decode([self.tokenizer.eot_token])
         else:
             eval_logger.warning(
...
@@ -354,7 +349,7 @@ class TemplateAPI(TemplateLM):
         return None

     @cached_property
-    def prefix_token_id(self) -> Optional[int]:
+    def prefix_token_id(self) -> int | None:
         if self.tokenizer is None:
             return None
         else:
...
@@ -364,24 +359,24 @@ class TemplateAPI(TemplateLM):
                 if self.tokenizer.bos_token_id is not None:
                     return self.tokenizer.bos_token_id
                 return self.tokenizer.eos_token_id
-            else:
-                return self.tokenizer.eot_token
+            return self.tokenizer.eot_token

     def tok_encode(
         self,
         string: str,
-        left_truncate_len: int = None,
+        left_truncate_len: int | None = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
         **kwargs,
-    ) -> Union[List[List[int]], List[int], List[str]]:
+    ) -> list[list[int]] | list[int] | list[str]:
         if self.tokenizer_backend is None:
             return [string]
-        elif self.tokenizer_backend == "huggingface":
+        if self.tokenizer_backend == "huggingface":
             # by default for CausalLM - false or self.add_bos_token is set
             if not add_special_tokens:
                 add_special_tokens = False or self.add_bos_token
-            encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+            encoding: list[list[int]] | list[int] = self.tokenizer(
                 string,
                 add_special_tokens=add_special_tokens,
                 truncation=truncation,
...
@@ -404,20 +399,20 @@ class TemplateAPI(TemplateLM):
             encoding = self.tokenizer.encode_batch(string)
         return encoding

-    def decode_batch(self, tokens: List[List[int]]) -> List[str]:
+    def decode_batch(self, tokens: list[list[int]]) -> list[str] | None:
         if self.tokenizer_backend == "huggingface":
             return self.tokenizer.batch_decode(tokens)
-        elif self.tokenizer_backend == "tiktoken":
+        if self.tokenizer_backend == "tiktoken":
             return self.tokenizer.decode_batch(tokens)

     def model_call(
         self,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         *,
         generate: bool = True,
-        gen_kwargs: Optional[Dict] = None,
+        gen_kwargs: dict | None = None,
         **kwargs,
-    ) -> Optional[dict]:
+    ) -> dict | None:
         # !!! Copy: shared dict for each request, need new object !!!
         gen_kwargs = copy.deepcopy(gen_kwargs)
         try:
...
@@ -441,7 +436,7 @@ class TemplateAPI(TemplateLM):
             response.raise_for_status()
             return response.json()
         except RetryError:
-            eval_logger.error(
+            eval_logger.exception(
                 "API request failed after multiple retries. Please check the API status."
             )
             return None
...
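One change in the hunk just above is not about annotations: inside the except RetryError handler, eval_logger.error becomes eval_logger.exception. The difference is standard-library logging behaviour rather than anything specific to this repo: Logger.exception logs at ERROR level and automatically attaches the active exception's traceback, which error() only does when exc_info=True is passed explicitly. A small sketch:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

try:
    1 / 0
except ZeroDivisionError:
    # error(): message only, no traceback unless exc_info=True is passed
    logger.error("request failed")
    # exception(): same ERROR level, but the current traceback is appended automatically
    logger.exception("request failed")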
@@ -450,14 +445,14 @@ class TemplateAPI(TemplateLM):
         self,
         session: ClientSession,
         sem: asyncio.Semaphore,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         *,
         generate: bool = True,
-        cache_keys: list = None,
-        ctxlens: Optional[List[int]] = None,
-        gen_kwargs: Optional[Dict] = None,
+        cache_keys: list | None = None,
+        ctxlens: list[int] | None = None,
+        gen_kwargs: dict | None = None,
         **kwargs,
-    ) -> Union[List[str], List[Tuple[float, bool]], None]:
+    ) -> list[str] | list[tuple[float, bool]] | None:
         # !!! Copy: shared dict for each request, need new object !!!
         gen_kwargs = copy.deepcopy(gen_kwargs)
         payload = self._create_payload(
...
@@ -508,8 +503,8 @@ class TemplateAPI(TemplateLM):
             sem.release()

     def batch_loglikelihood_requests(
-        self, chunks: Iterable[List[LogLikelihoodInputs]]
-    ) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
+        self, chunks: Iterable[list[LogLikelihoodInputs]]
+    ) -> tuple[list[list[int]], list[int], list[tuple[str, str]]]:
         inputs = []
         ctxlens = []
         cache_keys = []
...
@@ -536,9 +531,9 @@ class TemplateAPI(TemplateLM):
         cache_keys: list,
         *,
         generate: bool = True,
-        ctxlens: List[int] = None,
+        ctxlens: list[int] | None = None,
         **kwargs,
-    ) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
+    ) -> list[list[str]] | list[list[tuple[float, bool]]]:
         ctxlens = ctxlens if ctxlens else [None] * len(requests)
         conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
         sem = asyncio.Semaphore(self._concurrent)
...
@@ -575,14 +570,14 @@ class TemplateAPI(TemplateLM):
             return await tqdm_asyncio.gather(*tasks, desc="Requesting API")

-    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
+    def _loglikelihood_tokens(self, requests, **kwargs) -> list[tuple[float, bool]]:
         assert self.tokenizer is not None, (
             "Tokenizer is required for loglikelihood tasks to compute context lengths."
         )
         res = []

         def _collate(req: LogLikelihoodInputs):
-            """Defines the key for the sorted method"""
+            """Defines the key for the sorted method."""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
             # - to know the size of a batch when going through the list, you know the first one is always the batch
...
@@ -639,8 +634,8 @@ class TemplateAPI(TemplateLM):
         return re_ord.get_original(res)

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []

         def _collate_gen(_requests):
...
@@ -773,8 +768,8 @@ class TemplateAPI(TemplateLM):
         return re_ord.get_original(res)

     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         loglikelihoods = []

         for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
...
lm_eval/models/huggingface.py  View file @ 223b9488

...
@@ -682,8 +682,7 @@ class HFLM(TemplateLM):
             )

         if peft:
-            from peft import PeftModel
-            from peft import __version__ as PEFT_VERSION
+            from peft import PeftModel, __version__ as PEFT_VERSION

             if model_kwargs.get("load_in_4bit") and vparse(PEFT_VERSION) < vparse(
                 "0.4.0"
...
lm_eval/models/openai_completions.py  View file @ 223b9488

+from __future__ import annotations
+
 import logging
 import os
 from functools import cached_property
 from operator import itemgetter
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any

 from lm_eval.api.registry import register_model
 from lm_eval.models.api_models import TemplateAPI
...
@@ -26,9 +28,9 @@ class LocalCompletionsAPI(TemplateAPI):
     def _create_payload(
         self,
-        messages: Union[List[List[int]], List[dict], List[str], str],
+        messages: list[list[int]] | list[dict] | list[str] | str,
         generate=False,
-        gen_kwargs: Optional[dict] = None,
+        gen_kwargs: dict | None = None,
         seed: int = 1234,
         eos=None,
         **kwargs,
...
@@ -50,24 +52,23 @@ class LocalCompletionsAPI(TemplateAPI):
                 "seed": seed,
                 **gen_kwargs,
             }
-        else:
-            return {
-                "model": self.model,
-                "prompt": messages,
-                "temperature": 0,
-                "max_tokens": 1,
-                "logprobs": 1,
-                "seed": seed,
-                "echo": True,
-            }
+        return {
+            "model": self.model,
+            "prompt": messages,
+            "temperature": 0,
+            "max_tokens": 1,
+            "logprobs": 1,
+            "seed": seed,
+            "echo": True,
+        }

     @staticmethod
     def parse_logprobs(
-        outputs: Union[Dict, List[Dict]],
-        tokens: List[List[int]] = None,
-        ctxlens: List[int] = None,
+        outputs: dict | list[dict],
+        tokens: list[list[int]] = None,
+        ctxlens: list[int] = None,
         **kwargs,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
...
@@ -88,7 +89,7 @@ class LocalCompletionsAPI(TemplateAPI):
         return res

     @staticmethod
-    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
+    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
...
@@ -130,9 +131,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
     def _create_payload(
         self,
-        messages: List[Dict],
+        messages: list[dict],
         generate=False,
-        gen_kwargs: dict = None,
+        gen_kwargs: dict | None = None,
         seed=1234,
         eos=None,
         **kwargs,
...
@@ -160,7 +161,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
         }

     @staticmethod
-    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
+    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
...
@@ -173,11 +174,11 @@ class LocalChatCompletion(LocalCompletionsAPI):
     def tok_encode(
         self,
-        string: Union[str, Any],
+        string: str | Any,
         left_truncate_len=None,
         add_special_tokens=None,
         **kwargs,
-    ) -> Union[List[str], List[int], Any]:
+    ) -> list[str] | list[int] | Any:
         return string

     def loglikelihood(self, requests, **kwargs):
...
@@ -219,7 +220,7 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
         )
         return super().loglikelihood(requests, **kwargs)

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         return ""
...
@@ -261,7 +262,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
     def _create_payload(
         self,
-        messages: List[Dict],
+        messages: list[dict],
         generate=False,
         gen_kwargs: dict = None,
         seed=1234,
...
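Several hunks in this commit (the JsonChatStr branch in api_models.py and the _create_payload fallback above) also drop an else: that follows a return and dedent its body, which does not change behaviour. A small sketch of the pattern, with illustrative names not taken from the diff:

# Before: the else is redundant because the if-branch already returned.
def payload_with_else(generate: bool) -> dict:
    if generate:
        return {"mode": "generate"}
    else:
        return {"mode": "loglikelihood"}


# After: same behaviour, one level of nesting less.
def payload_without_else(generate: bool) -> dict:
    if generate:
        return {"mode": "generate"}
    return {"mode": "loglikelihood"}


assert payload_with_else(False) == payload_without_else(False)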
lm_eval/models/vllm_causallms.py  View file @ 223b9488

+from __future__ import annotations
+
 import copy
 import gc
 import logging
...
@@ -7,7 +9,7 @@ from importlib.util import find_spec
 from multiprocessing import Process, Queue
 from queue import Empty
 from time import sleep
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal

 import jinja2
 from more_itertools import distribute
...
@@ -113,30 +115,30 @@ class VLLM(TemplateLM):
         self,
         pretrained: str,
         dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
-        revision: Optional[str] = None,
-        trust_remote_code: Optional[bool] = False,
-        tokenizer: Optional[str] = None,
+        revision: str | None = None,
+        trust_remote_code: bool | None = False,
+        tokenizer: str | None = None,
         tokenizer_mode: Literal["auto", "slow"] = "auto",
-        tokenizer_revision: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
-        prefix_token_id: Optional[int] = None,
+        tokenizer_revision: str | None = None,
+        add_bos_token: bool | None = False,
+        prefix_token_id: int | None = None,
         tensor_parallel_size: int = 1,
-        quantization: Optional[str] = None,
+        quantization: str | None = None,
         max_gen_toks: int = 256,
         swap_space: int = 4,
-        batch_size: Union[str, int] = 1,
-        max_batch_size=None,
-        max_length: int = None,
-        max_model_len: int = None,
+        batch_size: str | int = 1,
+        max_batch_size: int | None = None,
+        max_length: int | None = None,
+        max_model_len: int | None = None,
         seed: int = 1234,
         gpu_memory_utilization: float = 0.9,
         data_parallel_size: int = 1,
-        lora_local_path: str = None,
+        lora_local_path: str | None = None,
         # VLLM: enable thinking tags in the prompt.
         enable_thinking: bool = True,
         chat_template_args: Optional[dict] = None,
         # End marker for thinking tags - splits to get response after this token (if provided).
-        think_end_token: Optional[str] = None,
+        think_end_token: str | None = None,
         max_lora_rank: int = 16,
         **kwargs,
     ):
...
@@ -172,7 +174,7 @@ class VLLM(TemplateLM):
             "swap_space": int(swap_space),
             "quantization": quantization,
             "seed": int(seed),
-            "enable_lora": True if lora_local_path else False,
+            "enable_lora": bool(lora_local_path),
             "max_lora_rank": int(max_lora_rank),
         }
         self.model_args.update(kwargs)
...
@@ -300,7 +302,7 @@ class VLLM(TemplateLM):
         return self._max_gen_toks

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
...
@@ -337,14 +339,14 @@ class VLLM(TemplateLM):
     def tok_encode(
         self,
-        string: Union[str, List[str]],
+        string: str | list[str],
         left_truncate_len: int = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
-    ) -> Union[List[int], List[List[int]]]:
+    ) -> list[int] | list[list[int]]:
         if not add_special_tokens:
             add_special_tokens = False or self.add_bos_token
-        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+        encoding: list[list[int]] | list[int] = self.tokenizer(
             string,
             add_special_tokens=add_special_tokens,
             truncation=truncation,
...
@@ -362,7 +364,7 @@ class VLLM(TemplateLM):
     def _model_generate(
         self,
-        requests: List[List[int]] = None,
+        requests: list[list[int]] = None,
         generate: bool = False,
         sampling_params: Union[List["SamplingParams"], "SamplingParams", None] = None,
     ):
...
@@ -379,8 +381,8 @@ class VLLM(TemplateLM):
             @ray.remote
             def run_inference_one_model(
                 model_args: dict,
-                sampling_params: List["SamplingParams"],
-                requests: List[List[int]],
+                sampling_params: list["SamplingParams"],
+                requests: list[list[int]],
                 lora_request: "LoRARequest",
             ):
                 llm = LLM(**model_args)
...
@@ -454,7 +456,7 @@ class VLLM(TemplateLM):
                     if dead_procs:
                         raise RuntimeError(
                             f"Worker processes {dead_procs} died unexpectedly"
-                        )
+                        ) from None
                     continue
                 results = [rank_res[i] for i in range(len(procs))]
...
@@ -481,14 +483,14 @@ class VLLM(TemplateLM):
             outputs = self.model.generate(
                 [TokensPrompt(prompt_token_ids=request) for request in requests],
                 sampling_params=sampling_params,
-                use_tqdm=True if self.batch_size == "auto" else False,
+                use_tqdm=self.batch_size == "auto",
                 lora_request=self.lora_request,
             )
             return outputs

     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             adaptive_batch_size = len(requests)
...
@@ -503,7 +505,7 @@ class VLLM(TemplateLM):
                 disable=(disable_tqdm or (self.rank != 0)),
             )
         ):
-            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                 map(
                     make_disjoint_window,
                     get_rolling_token_windows(
...
@@ -556,13 +558,13 @@ class VLLM(TemplateLM):
         return loglikelihoods

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []

         # batch tokenize contexts
         context, all_gen_kwargs = zip(*(req.args for req in requests))
-        context_encoding: List[List[int]] = self.tok_encode(
+        context_encoding: list[list[int]] = self.tok_encode(
             context, add_special_tokens=self.add_bos_token
         )
         requests = [
...
@@ -638,7 +640,7 @@ class VLLM(TemplateLM):
         )

         # cache generations
-        for output, context in zip(cont, context):
+        for output, context_ in zip(cont, context):
             generated_text: str = output.outputs[0].text
             # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
             generated_text = postprocess_generated_text(
...
@@ -646,7 +648,7 @@ class VLLM(TemplateLM):
             )
             res.append(generated_text)
             self.cache_hook.add_partial(
-                "generate_until", (context, gen_kwargs), generated_text
+                "generate_until", (context_, gen_kwargs), generated_text
             )
             pbar.update(1)
...
@@ -656,9 +658,9 @@ class VLLM(TemplateLM):
     def _loglikelihood_tokens(
         self,
-        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        requests: list[tuple[tuple[str, str], list[int], list[int]]],
         disable_tqdm: bool = False,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []

         def _collate(x):
...
@@ -679,7 +681,7 @@ class VLLM(TemplateLM):
         for chunk in chunks:
             inputs = []
             ctxlens = []
-            for cache_key, context_enc, continuation_enc in chunk:
+            for _cache_key, context_enc, continuation_enc in chunk:
                 if (full_length := len(context_enc + continuation_enc)) > self.max_length:
...
@@ -717,7 +719,7 @@ class VLLM(TemplateLM):
         return re_ord.get_original(res)

     @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
+    def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
         """Process logprobs and tokens.

         :param tokens: list
...
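Two of the vllm_causallms.py hunks above are small cleanups rather than annotation changes: `True if lora_local_path else False` becomes `bool(lora_local_path)` (and `use_tqdm=True if ... else False` becomes a plain comparison), and the loop target in `for output, context in zip(cont, context)` is renamed to `context_` so the per-item value no longer reuses the name of the tuple being zipped over. Behaviour is unchanged either way; the rename just removes the shadowing that linters flag. A hedged illustration with toy names, not the harness's own code:

# bool(x) is equivalent to `True if x else False` for any value x.
lora_local_path = ""  # hypothetical stand-in for the constructor argument
assert (True if lora_local_path else False) == bool(lora_local_path)

# Renaming the loop target keeps the zipped-over sequence's name intact.
cont = ["generation-a", "generation-b"]   # toy stand-ins for vLLM outputs
context = ["prompt-a", "prompt-b"]        # toy stand-ins for the prompts
for output, context_ in zip(cont, context):
    cache_key = (context_, output)        # per-item value, as used for the cache entry
assert context == ["prompt-a", "prompt-b"]  # the outer name was never rebound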