gaoqiong / lm-evaluation-harness · Commit 1f97a945

Authored Jul 23, 2025 by Baber
Commit message: types
Parent: 0087929e
Showing 4 changed files with 273 additions and 278 deletions (+273 / -278):

  lm_eval/models/api_models.py           +67   -72
  lm_eval/models/huggingface.py          +138  -141
  lm_eval/models/openai_completions.py   +26   -25
  lm_eval/models/vllm_causallms.py       +42   -40
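This commit is a typing-modernization pass over the four model backends: `Optional[...]`/`Union[...]` and the capitalized `typing` generics (`List`, `Dict`, `Tuple`) are rewritten as PEP 604 unions (`X | None`, `X | Y`) and PEP 585 builtin generics (`list`, `dict`, `tuple`), and `from __future__ import annotations` is added at the top of each file so the new spelling stays valid on interpreters that predate those PEPs. A minimal before/after sketch of the pattern (the function and types below are illustrative, not taken from the diff):

from __future__ import annotations  # PEP 563: annotations are stored as strings, not evaluated


# before: typing-module generics and Optional
#   def first_match(ids: Optional[List[int]], table: Dict[int, str]) -> Optional[str]: ...

# after: builtin generics (PEP 585) and | unions (PEP 604)
def first_match(ids: list[int] | None, table: dict[int, str]) -> str | None:
    """Return the label of the first id present in `table`, or None."""
    if not ids:
        return None
    for i in ids:
        if i in table:
            return table[i]
    return None


print(first_match([3, 7], {7: "seven"}))  # seven

Because the `__future__` import defers evaluation, the new syntax only has to parse; it does not need `list[int]` (3.9+) or `X | Y` (3.10+) to be executable at runtime, which is what makes the whole-file rewrite safe with a single added import.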
lm_eval/models/api_models.py

+from __future__ import annotations
+
 import abc
 import asyncio
 import copy
...
@@ -8,16 +10,9 @@ from functools import cached_property
 from typing import (
     TYPE_CHECKING,
     Any,
-    Awaitable,
     Callable,
-    Dict,
-    Iterable,
-    List,
     Literal,
     NamedTuple,
-    Optional,
-    Tuple,
-    Union,
 )
...
@@ -36,18 +31,21 @@ from importlib.util import find_spec
 from io import BytesIO

 from lm_eval import utils
-from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.models.utils import Collator, chunks, configure_pad_token


 if TYPE_CHECKING:
+    from collections.abc import Awaitable, Iterable
+
     from PIL import Image
+
+    from lm_eval.api.instance import Instance


 eval_logger = logging.getLogger(__name__)

-LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
+LogLikelihoodInputs = tuple[tuple[str, str], list[int], list[int]]


 # utility class to keep track of json encoded chats
...
@@ -58,9 +56,7 @@ class JsonChatStr(NamedTuple):
         return self.prompt.encode(encoding)


 def create_image_prompt(
-    imgs: list[Image.Image], chat: dict, fmt: str = "PNG"
+    imgs: list["Image.Image"], chat: dict, fmt: str = "PNG"
 ) -> dict:
     """
     Parameters
...
@@ -109,33 +105,32 @@ class TemplateAPI(TemplateLM):
         model: str = None,
         pretrained: str = None,  # `model` takes precedence over `pretrained` when passed.
         base_url: str = None,
-        tokenizer: Optional[str] = None,
+        tokenizer: str | None = None,
         # Loglikelihood tasks require a tokenizer to calculate context lengths,
         # however the requests can be sent as a string if the API doesn't support token inputs.
         # use tokenized_requests=False
-        tokenizer_backend: Optional[
-            Literal["tiktoken", "huggingface", "None", "none"]
-        ] = "huggingface",
+        tokenizer_backend: Literal["tiktoken", "huggingface", "None", "none"]
+        | None = "huggingface",
         truncate: bool = False,
         # number of concurrent requests. More useful if not batching
         num_concurrent: int = 1,
         max_retries: int = 3,
         max_gen_toks: int = 256,
-        batch_size: Union[str, int] = 1,
+        batch_size: str | int = 1,
         seed: int = 1234,
-        max_length: Optional[int] = 2048,
+        max_length: int | None = 2048,
         add_bos_token: bool = False,
         custom_prefix_token_id: int = None,
         # send the requests as tokens or strings
         tokenized_requests: bool = True,
         trust_remote_code: bool = False,
-        revision: Optional[str] = "main",
+        revision: str | None = "main",
         use_fast_tokenizer: bool = True,
         verify_certificate: bool = True,
         eos_string: str = None,
         # timeout in seconds
         timeout: int = 300,
-        header: Optional[Dict[str, str]] = None,
+        header: dict[str, str] | None = None,
         max_images: int = 1,
         **kwargs,
     ) -> None:
...
@@ -232,12 +227,12 @@ class TemplateAPI(TemplateLM):
     @abc.abstractmethod
     def _create_payload(
         self,
-        messages: Union[List[List[int]], List[dict], List[str], str],
+        messages: list[list[int]] | list[dict] | list[str] | str,
         *,
         generate: bool = True,
-        gen_kwargs: Optional[dict] = None,
+        gen_kwargs: dict | None = None,
         seed: int = 1234,
-        eos: str = None,
+        eos: str | None = None,
         **kwargs,
     ) -> dict:
         """This method is responsible for creating the json payload that will be sent to the API."""
...
@@ -245,9 +240,9 @@ class TemplateAPI(TemplateLM):
     def create_message(
         self,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         generate=False,
-    ) -> Union[List[List[int]], List[dict], List[str], str]:
+    ) -> list[list[int]] | list[dict] | list[str] | str:
         """Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
         if isinstance(messages[0], JsonChatStr):
             # for chat completions we need to decode the json string to list[dict,...]
...
@@ -276,17 +271,17 @@ class TemplateAPI(TemplateLM):
     @staticmethod
     @abc.abstractmethod
     def parse_logprobs(
-        outputs: Union[Any, List[Any]],
-        tokens: List[List[int]] = None,
-        ctxlen: List[int] = None,
+        outputs: Any | list[Any],
+        tokens: list[list[int]] | None = None,
+        ctxlen: list[int] | None = None,
         **kwargs,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         """Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
         raise NotImplementedError

     @staticmethod
     @abc.abstractmethod
-    def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
+    def parse_generations(outputs: Any | list[Any], **kwargs) -> list[str]:
         """Method used to parse the generations from the (batched) API response. This method should return a list of str"""
         raise NotImplementedError
...
@@ -303,14 +298,15 @@ class TemplateAPI(TemplateLM):
     @property
     def tokenizer_name(self) -> str:
         """Must be defined for LM subclasses which implement Chat Templating.
         Should return the name of the tokenizer or chat template used.
         Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
         """
         return ""

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
-    ) -> Union[str, JsonChatStr]:
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
+    ) -> str | JsonChatStr:
         """Applies a chat template to a list of chat history between user and model."""
         if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
             return self.tokenizer.apply_chat_template(
...
@@ -319,7 +315,6 @@ class TemplateAPI(TemplateLM):
                 add_generation_prompt=add_generation_prompt,
                 continue_final_message=not add_generation_prompt,
             )
-        else:
             # bit of a hack. We'll load back before sending to the API
             return JsonChatStr(
                 json.dumps(
...
@@ -329,23 +324,23 @@ class TemplateAPI(TemplateLM):
             )

     @cached_property
-    def eot_token_id(self) -> Optional[int]:
+    def eot_token_id(self) -> int | None:
         if self.tokenizer is None:
             return None
         else:
             if self.tokenizer_backend == "huggingface":
                 return self.tokenizer.eos_token_id
             elif self.tokenizer_backend == "tiktoken":
                 return self.tokenizer.eot_token

     @cached_property
-    def eos_string(self) -> Optional[str]:
+    def eos_string(self) -> str | None:
         if self._eos_string:
             return self._eos_string
         elif self.tokenizer is not None:
             if self.tokenizer_backend == "huggingface":
                 return self.tokenizer.eos_token
             elif self.tokenizer_backend == "tiktoken":
                 return self.tokenizer.decode([self.tokenizer.eot_token])
         else:
             eval_logger.warning(
...
@@ -354,7 +349,7 @@ class TemplateAPI(TemplateLM):
             return None

     @cached_property
-    def prefix_token_id(self) -> Optional[int]:
+    def prefix_token_id(self) -> int | None:
         if self.tokenizer is None:
             return None
         else:
...
@@ -364,24 +359,24 @@ class TemplateAPI(TemplateLM):
             if self.tokenizer.bos_token_id is not None:
                 return self.tokenizer.bos_token_id
             return self.tokenizer.eos_token_id
         else:
             return self.tokenizer.eot_token

     def tok_encode(
         self,
         string: str,
-        left_truncate_len: int = None,
+        left_truncate_len: int | None = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
         **kwargs,
-    ) -> Union[List[List[int]], List[int], List[str]]:
+    ) -> list[list[int]] | list[int] | list[str]:
         if self.tokenizer_backend is None:
             return [string]
         elif self.tokenizer_backend == "huggingface":
             # by default for CausalLM - false or self.add_bos_token is set
             if not add_special_tokens:
                 add_special_tokens = False or self.add_bos_token
-            encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+            encoding: list[list[int]] | list[int] = self.tokenizer(
                 string,
                 add_special_tokens=add_special_tokens,
                 truncation=truncation,
...
@@ -404,20 +399,20 @@ class TemplateAPI(TemplateLM):
             encoding = self.tokenizer.encode_batch(string)
             return encoding

-    def decode_batch(self, tokens: List[List[int]]) -> List[str]:
+    def decode_batch(self, tokens: list[list[int]]) -> list[str] | None:
         if self.tokenizer_backend == "huggingface":
             return self.tokenizer.batch_decode(tokens)
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.decode_batch(tokens)

     def model_call(
         self,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         *,
         generate: bool = True,
-        gen_kwargs: Optional[Dict] = None,
+        gen_kwargs: dict | None = None,
         **kwargs,
-    ) -> Optional[dict]:
+    ) -> dict | None:
         # !!! Copy: shared dict for each request, need new object !!!
         gen_kwargs = copy.deepcopy(gen_kwargs)
         try:
...
@@ -441,7 +436,7 @@ class TemplateAPI(TemplateLM):
                 response.raise_for_status()
                 return response.json()
         except RetryError:
-            eval_logger.error(
+            eval_logger.exception(
                 "API request failed after multiple retries. Please check the API status."
             )
             return None
...
@@ -450,14 +445,14 @@ class TemplateAPI(TemplateLM):
         self,
         session: ClientSession,
         sem: asyncio.Semaphore,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         *,
         generate: bool = True,
-        cache_keys: list = None,
-        ctxlens: Optional[List[int]] = None,
-        gen_kwargs: Optional[Dict] = None,
+        cache_keys: list | None = None,
+        ctxlens: list[int] | None = None,
+        gen_kwargs: dict | None = None,
         **kwargs,
-    ) -> Union[List[str], List[Tuple[float, bool]], None]:
+    ) -> list[str] | list[tuple[float, bool]] | None:
         # !!! Copy: shared dict for each request, need new object !!!
         gen_kwargs = copy.deepcopy(gen_kwargs)
         payload = self._create_payload(
...
@@ -508,8 +503,8 @@ class TemplateAPI(TemplateLM):
             sem.release()

     def batch_loglikelihood_requests(
-        self, chunks: Iterable[List[LogLikelihoodInputs]]
-    ) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
+        self, chunks: Iterable[list[LogLikelihoodInputs]]
+    ) -> tuple[list[list[int]], list[int], list[tuple[str, str]]]:
         inputs = []
         ctxlens = []
         cache_keys = []
...
@@ -536,9 +531,9 @@ class TemplateAPI(TemplateLM):
         cache_keys: list,
         *,
         generate: bool = True,
-        ctxlens: List[int] = None,
+        ctxlens: list[int] | None = None,
         **kwargs,
-    ) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
+    ) -> list[list[str]] | list[list[tuple[float, bool]]]:
         ctxlens = ctxlens if ctxlens else [None] * len(requests)
         conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
         sem = asyncio.Semaphore(self._concurrent)
...
@@ -575,14 +570,14 @@ class TemplateAPI(TemplateLM):
             return await tqdm_asyncio.gather(*tasks, desc="Requesting API")

-    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
+    def _loglikelihood_tokens(self, requests, **kwargs) -> list[tuple[float, bool]]:
         assert self.tokenizer is not None, (
             "Tokenizer is required for loglikelihood tasks to compute context lengths."
         )
         res = []

         def _collate(req: LogLikelihoodInputs):
-            """Defines the key for the sorted method"""
+            """Defines the key for the sorted method."""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
             # - to know the size of a batch when going through the list, you know the first one is always the batch
...
@@ -639,8 +634,8 @@ class TemplateAPI(TemplateLM):
         return re_ord.get_original(res)

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []

         def _collate_gen(_requests):
...
@@ -773,8 +768,8 @@ class TemplateAPI(TemplateLM):
         return re_ord.get_original(res)

     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         loglikelihoods = []

         for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
...
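One side effect of deferred annotations shows up in the import hunks above: names that are only ever used inside annotations (`Awaitable`, `Iterable`, `Image`, `Instance`) can live in an `if TYPE_CHECKING:` block, so they cost nothing at import time and cannot create circular imports, while type checkers still resolve them. A minimal sketch of that arrangement, assuming the same module layout as the diff (the function itself is illustrative, not repo code):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported only for the type checker; never executed at runtime
    from collections.abc import Iterable

    from lm_eval.api.instance import Instance


def count_requests(requests: Iterable[Instance]) -> int:
    # the annotation is a string at runtime, so the TYPE_CHECKING-only imports
    # are sufficient for mypy/pyright and free for ordinary execution
    return sum(1 for _ in requests)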
lm_eval/models/huggingface.py

+from __future__ import annotations
+
 import copy
 import logging
 import os
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal

 import jinja2
 import torch
...
@@ -40,7 +42,7 @@ from lm_eval.models.utils import (
 if TYPE_CHECKING:
-    from transformers.quantizers import AutoQuantizationConfig
+    from transformers.quantizers.auto import AutoQuantizationConfig

 eval_logger = logging.getLogger(__name__)
...
@@ -59,46 +61,43 @@ class HFLM(TemplateLM):
     def __init__(
         self,
-        pretrained: Union[str, transformers.PreTrainedModel],
+        pretrained: str | transformers.PreTrainedModel,
         backend: Literal["default", "causal", "seq2seq"] = "default",
         # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
-        revision: Optional[str] = "main",
+        revision: str | None = "main",
         subfolder: str = "",
-        tokenizer: Optional[
-            Union[
-                str,
-                transformers.PreTrainedTokenizer,
-                transformers.PreTrainedTokenizerFast,
-            ]
-        ] = None,
-        truncation: Optional[bool] = False,
+        tokenizer: str
+        | transformers.PreTrainedTokenizer
+        | transformers.PreTrainedTokenizerFast
+        | None = None,
+        truncation: bool | None = False,
         logits_cache: bool = True,
-        max_length: Optional[int] = None,
-        device: Optional[str] = "cuda",
-        dtype: Optional[Union[str, torch.dtype]] = "auto",
-        softmax_dtype: Optional[Union[str, torch.dtype]] = None,
-        mixed_precision_dtype: Optional[Union[str, torch.dtype]] = None,
-        batch_size: Optional[Union[int, str]] = 1,
-        max_batch_size: Optional[int] = 64,
-        trust_remote_code: Optional[bool] = False,
-        use_fast_tokenizer: Optional[bool] = True,
-        add_bos_token: Optional[bool] = False,
-        prefix_token_id: Optional[int] = None,
+        max_length: int | None = None,
+        device: str | None = "cuda",
+        dtype: str | torch.dtype | None = "auto",
+        softmax_dtype: str | torch.dtype | None = None,
+        mixed_precision_dtype: str | torch.dtype | None = None,
+        batch_size: int | str | None = 1,
+        max_batch_size: int | None = 64,
+        trust_remote_code: bool | None = False,
+        use_fast_tokenizer: bool | None = True,
+        add_bos_token: bool | None = False,
+        prefix_token_id: int | None = None,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
-        parallelize: Optional[bool] = False,
-        max_memory_per_gpu: Optional[Union[int, str]] = None,
-        max_cpu_memory: Optional[Union[int, str]] = None,
-        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
+        parallelize: bool | None = False,
+        max_memory_per_gpu: int | str | None = None,
+        max_cpu_memory: int | str | None = None,
+        offload_folder: str | os.PathLike | None = "./offload",
         # PEFT, delta weights and quantization options
-        peft: Optional[str] = None,
-        delta: Optional[str] = None,
-        autogptq: Optional[Union[bool, str]] = False,
-        gptqmodel: Optional[bool] = False,
-        gguf_file: Optional[str] = None,
+        peft: str | None = None,
+        delta: str | None = None,
+        autogptq: bool | str | None = False,
+        gptqmodel: bool | None = False,
+        gguf_file: str | None = None,
         # end token for thinking, either the string or int token id.
         # splits to get response after this token (if provided).
-        think_end_token: Union[str, int, None] = None,
+        think_end_token: str | int | None = None,
         **kwargs,
     ) -> None:
         super().__init__()
...
@@ -271,9 +270,10 @@ class HFLM(TemplateLM):
             self.batch_size_per_gpu = int(batch_size)

         if isinstance(pretrained, str):
-            if gpus >= 1 or str(self.device) == "mps":
+            if (gpus >= 1 or str(self.device) == "mps") and not (
+                parallelize or autogptq or hasattr(self, "accelerator")
+            ):
                 # TODO: can remove this whole snippet except in the mps case, perhaps?
-                if not (parallelize or autogptq or hasattr(self, "accelerator")):
                 # place model onto device requested manually,
                 # if not using HF Accelerate or device_map
                 # or any other option that preloads model onto device
...
@@ -327,12 +327,12 @@ class HFLM(TemplateLM):
     def _get_accelerate_args(
         self,
-        parallelize: Optional[bool] = None,
-        device_map: Optional[str] = "auto",
-        max_memory_per_gpu: Optional[Union[int, str]] = None,
-        max_cpu_memory: Optional[Union[int, str]] = None,
-        offload_folder: Optional[str] = "./offload",
-        gpus: Optional[int] = None,
+        parallelize: bool | None = None,
+        device_map: str | None = "auto",
+        max_memory_per_gpu: int | str | None = None,
+        max_cpu_memory: int | str | None = None,
+        offload_folder: str | None = "./offload",
+        gpus: int | None = None,
     ) -> dict:
         """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
         num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
...
@@ -480,9 +480,9 @@ class HFLM(TemplateLM):
     def _get_backend(
         self,
-        config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
+        config: transformers.PretrainedConfig | transformers.AutoConfig,
         backend: Literal["default", "causal", "seq2seq"] = "default",
-        trust_remote_code: Optional[bool] = False,
+        trust_remote_code: bool | None = False,
     ) -> None:
         """
         Helper method during initialization.
...
@@ -497,27 +497,20 @@ class HFLM(TemplateLM):
         if backend != "default":
             # if we've settled on non-default backend, use that manually
-            if backend == "causal":
-                self.backend = backend
-            elif backend == "seq2seq":
+            if backend in ["causal", "seq2seq"]:
                 self.backend = backend
             eval_logger.info(
                 f"Overrode HF model backend type, and using type '{self.backend}'"
             )
         else:
             # determine and use the default HF backend for this model, based on its config + metadata.
-            if (
-                getattr(config, "model_type")
-                in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
-            ):
+            if self.config.model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                 # first check if model type is listed under seq2seq models, since some
                 # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                 # these special cases should be treated as seq2seq models.
                 self.backend = "seq2seq"
                 eval_logger.debug(f"Using model type '{self.backend}'")
-            elif (
-                getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
-            ):
+            elif self.config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                 self.backend = "causal"
                 eval_logger.debug(f"Using model type '{self.backend}'")
             else:
...
@@ -545,7 +538,7 @@ class HFLM(TemplateLM):
         pretrained: str,
         revision: str = "main",
         trust_remote_code: bool = False,
-        gguf_file: Optional[str] = None,
+        gguf_file: str | None = None,
         subfolder: str = "",
     ) -> None:
         """Return the model config for HuggingFace models"""
...
@@ -560,24 +553,24 @@ class HFLM(TemplateLM):
     def _create_model(
         self,
         pretrained: str,
-        revision: Optional[str] = "main",
-        dtype: Optional[Union[str, torch.dtype]] = "auto",
-        trust_remote_code: Optional[bool] = False,
+        revision: str | None = "main",
+        dtype: str | torch.dtype | None = "auto",
+        trust_remote_code: bool | None = False,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
         # (accelerate naive PP (device_map) options)
-        parallelize: Optional[bool] = False,
-        gpus: Optional[int] = None,
-        max_memory_per_gpu: Optional[Union[int, str]] = None,
-        max_cpu_memory: Optional[Union[int, str]] = None,
-        offload_folder: Optional[str] = "./offload",
+        parallelize: bool | None = False,
+        gpus: int | None = None,
+        max_memory_per_gpu: int | str | None = None,
+        max_cpu_memory: int | str | None = None,
+        offload_folder: str | None = "./offload",
         # PEFT, delta weights and quantization options
-        peft: Optional[str] = None,
-        delta: Optional[str] = None,
-        autogptq: Optional[Union[bool, str]] = False,
-        gptqmodel: Optional[bool] = False,
-        gguf_file: Optional[str] = None,
-        quantization_config: Optional["AutoQuantizationConfig"] = None,
+        peft: str | None = None,
+        delta: str | None = None,
+        autogptq: bool | str | None = False,
+        gptqmodel: bool | None = False,
+        gguf_file: str | None = None,
+        quantization_config: AutoQuantizationConfig | None = None,
         subfolder: str = "",
         **kwargs,
     ) -> None:
...
@@ -598,7 +591,7 @@ class HFLM(TemplateLM):
             model_kwargs.update(
                 self._get_accelerate_args(
                     parallelize=parallelize,
-                    device_map=kwargs.get("device_map", None),
+                    device_map=kwargs.get("device_map"),
                     max_memory_per_gpu=max_memory_per_gpu,
                     max_cpu_memory=max_cpu_memory,
                     offload_folder=offload_folder,
...
@@ -611,12 +604,11 @@ class HFLM(TemplateLM):
                 assert transformers.__version__ >= "4.30.0", (
                     "load_in_4bit requires transformers >= 4.30.0"
                 )
-            if transformers.__version__ >= "4.30.0":
-                if model_kwargs.get("load_in_4bit", None):
-                    if model_kwargs.get("bnb_4bit_compute_dtype", None):
-                        model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
-                            model_kwargs["bnb_4bit_compute_dtype"]
-                        )
+            if transformers.__version__ >= "4.30.0" and (
+                model_kwargs.get("load_in_4bit")
+                and (compute_dtype := model_kwargs.get("bnb_4bit_compute_dtype"))
+            ):
+                model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(compute_dtype)
         self._model = self.AUTO_MODEL_CLASS.from_pretrained(
             pretrained,
...
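The `load_in_4bit` hunk above folds three nested `if` statements into a single condition and uses the walrus operator (`:=`) to capture the looked-up `bnb_4bit_compute_dtype` for reuse in the body. A standalone sketch of the same pattern with a hypothetical config dict (`normalize` stands in for the repo's `get_dtype` helper):

config = {"load_in_4bit": True, "bnb_4bit_compute_dtype": "BFloat16"}


def normalize(name: str) -> str:
    # stand-in for get_dtype(); just canonicalizes the spelling here
    return name.lower()


# before:
#   if config.get("load_in_4bit"):
#       if config.get("bnb_4bit_compute_dtype"):
#           config["bnb_4bit_compute_dtype"] = normalize(config["bnb_4bit_compute_dtype"])

# after: one condition; := binds the value so the body avoids a second lookup
if config.get("load_in_4bit") and (compute_dtype := config.get("bnb_4bit_compute_dtype")):
    config["bnb_4bit_compute_dtype"] = normalize(compute_dtype)

print(config["bnb_4bit_compute_dtype"])  # bfloat16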
@@ -641,7 +633,7 @@ class HFLM(TemplateLM):
                 raise type(exception)(
                     "Tried to load auto_gptq, but auto-gptq is not installed ",
                     "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
-                )
+                ) from exception

             self._model = AutoGPTQForCausalLM.from_quantized(
                 pretrained,
...
@@ -660,7 +652,7 @@ class HFLM(TemplateLM):
                 raise type(exception)(
                     "Tried to load gptqmodel, but gptqmodel is not installed ",
                     "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
-                )
+                ) from exception

             self._model = GPTQModel.from_quantized(
                 pretrained, trust_remote_code=trust_remote_code, **model_kwargs
...
@@ -672,11 +664,11 @@ class HFLM(TemplateLM):
             )

         if peft:
-            from peft import PeftModel
-            from peft import __version__ as PEFT_VERSION
+            from peft import PeftModel, __version__ as PEFT_VERSION

-            if model_kwargs.get("load_in_4bit", None):
-                if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
-                    raise AssertionError("load_in_4bit requires peft >= 0.4.0")
+            if model_kwargs.get("load_in_4bit") and version.parse(
+                PEFT_VERSION
+            ) < version.parse("0.4.0"):
+                raise AssertionError("load_in_4bit requires peft >= 0.4.0")
             if self._model.config.vocab_size != len(self.tokenizer):
                 # resize model for LoRAs with added tokens
...
@@ -703,11 +695,13 @@ class HFLM(TemplateLM):
                     try:
                         param.data += _model_delta.state_dict()[name]
                     except KeyError:
-                        raise KeyError(f"Delta model is missing weights for layer: {name}")
+                        raise KeyError(
+                            f"Delta model is missing weights for layer: {name}"
+                        ) from None
                     except Exception as e:
                         raise RuntimeError(
                             f"Failed to add delta weights to layer {name}. Error: {e}"
-                        )
+                        ) from e

             del _model_delta
...
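The `raise ... from exception`, `from None`, and `from e` additions in the hunks above make exception chaining explicit: `from e` (or `from exception`) records the original error as `__cause__` so the full traceback survives, while `from None` suppresses the implicit "During handling of the above exception, another exception occurred" context. A small illustrative sketch (not repo code) of both behaviours:

weights = {"encoder.bias": 0.1}


def delta_for(name: str) -> float:
    try:
        return weights[name]
    except KeyError:
        # from None: the caller sees only this message, not the bare KeyError context
        raise KeyError(f"Delta model is missing weights for layer: {name}") from None
    except Exception as e:
        # from e: keeps the original exception attached as __cause__ for debugging
        raise RuntimeError(f"Failed to add delta weights to layer {name}. Error: {e}") from e


try:
    delta_for("decoder.bias")
except KeyError as err:
    print(err.__suppress_context__, err.__cause__)  # True None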
@@ -715,20 +709,17 @@ class HFLM(TemplateLM):
     def _create_tokenizer(
         self,
-        pretrained: Union[str, transformers.PreTrainedModel],
-        tokenizer: Optional[
-            Union[
-                str,
-                transformers.PreTrainedTokenizer,
-                transformers.PreTrainedTokenizerFast,
-            ]
-        ],
-        revision: Optional[str] = "main",
-        trust_remote_code: Optional[bool] = False,
-        use_fast_tokenizer: Optional[bool] = True,
-        gguf_file: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
-        subfolder: Optional[str] = "",
+        pretrained: str | transformers.PreTrainedModel,
+        tokenizer: str
+        | transformers.PreTrainedTokenizer
+        | transformers.PreTrainedTokenizerFast
+        | None,
+        revision: str | None = "main",
+        trust_remote_code: bool | None = False,
+        use_fast_tokenizer: bool | None = True,
+        gguf_file: str | None = None,
+        add_bos_token: bool | None = False,
+        subfolder: str | None = "",
     ) -> None:
         """
         Helper method during initialization.
...
@@ -760,8 +751,12 @@ class HFLM(TemplateLM):
                 )
             else:
-                assert isinstance(
-                    tokenizer, transformers.PreTrainedTokenizer
-                ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
+                assert isinstance(
+                    tokenizer,
+                    (
+                        transformers.PreTrainedTokenizer,
+                        transformers.PreTrainedTokenizerFast,
+                    ),
+                )
                 self.tokenizer = tokenizer
         else:
             # Get tokenizer based on 'pretrained'
...
@@ -838,7 +833,7 @@ class HFLM(TemplateLM):
     def tok_encode(
         self, string: str, left_truncate_len=None, add_special_tokens=None
-    ) -> List[int]:
+    ) -> list[int]:
         """ """
         # default for None - empty dict, use predefined tokenizer param
         # used for all models except for CausalLM or predefined value
...
@@ -864,11 +859,11 @@ class HFLM(TemplateLM):
     def tok_batch_encode(
         self,
-        strings: List[str],
+        strings: list[str],
         padding_side: str = "left",
         left_truncate_len: int = None,
         truncation: bool = False,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
         old_padding_side = self.tokenizer.padding_side
         self.tokenizer.padding_side = padding_side
...
@@ -917,15 +912,17 @@ class HFLM(TemplateLM):
             A torch tensor of shape [batch, sequence, vocab] with the
         logits returned from the model's decoder
         """
-        with torch.no_grad():
-            with torch.autocast(
+        with (
+            torch.no_grad(),
+            torch.autocast(
                 device_type=self.device.type,
                 dtype=self.mixed_precision_dtype,
                 enabled=self.mixed_precision_dtype is not None,
-            ):
+            ),
+        ):
             if attn_mask is not None or labels is not None:
                 assert attn_mask is not None and labels is not None
-                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
+                assert transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS
                 return self.model(
                     input_ids=inps, attention_mask=attn_mask, labels=labels
                 ).logits
...
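The `_model_call` hunk replaces the nested `with torch.no_grad(): with torch.autocast(...):` pair with a single parenthesized `with` statement holding both context managers, a form supported from Python 3.10. A minimal runnable sketch of the same construction (the toy model and dtype choices here are assumptions, not the repo's):

import torch

model = torch.nn.Linear(4, 2)
inps = torch.randn(3, 4)

with (
    torch.no_grad(),  # no gradient tracking for the forward pass
    torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True),
):
    logits = model(inps)

print(logits.shape)  # torch.Size([3, 2])

The unparenthesized form `with a, b:` has always been legal; the parentheses only add the ability to split the managers across lines, which is what the diff uses.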
@@ -942,7 +939,7 @@ class HFLM(TemplateLM):
             # remove temperature, as do_sample=False takes care of this
             # and we don't want a warning from HF
             generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
-            do_sample = generation_kwargs.get("do_sample", None)
+            do_sample = generation_kwargs.get("do_sample")

             # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
             if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
...
@@ -989,8 +986,8 @@ class HFLM(TemplateLM):
         return logits

     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
...
@@ -1009,7 +1006,7 @@ class HFLM(TemplateLM):
                 disable=(disable_tqdm or (self.rank != 0)),
             )
         ):
-            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                 map(
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
...
@@ -1093,14 +1090,14 @@ class HFLM(TemplateLM):
     def _loglikelihood_tokens(
         self,
-        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        requests: list[tuple[tuple[str, str], list[int], list[int]]],
         disable_tqdm: bool = False,
         override_bs: int = None,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []

-        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
+        def _collate(req: tuple[tuple[str, str], list[int], list[int]]):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
...
@@ -1112,7 +1109,7 @@ class HFLM(TemplateLM):
             toks = req[1] + req[2]
             return -len(toks), tuple(toks)

-        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
+        def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]):
             """Defines the key to group and lookup one-token continuations"""
             # Use with group_by="contexts" (optional)"
             # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
...
@@ -1286,7 +1283,7 @@ class HFLM(TemplateLM):
             # original args. Otherwise, expands the logits batch dimension and yields each
             # batch along with matching continuation tokens and prompt strings.
             # logits -> [1, seq, vocab]
-            for request_str, cont_toks, logits in re_ord.get_cache(
+            for request_str, cont_toks, logits in re_ord.get_cache(  # noqa
                 req_str=request_str,
                 cxt_toks=ctx_tokens,
                 cont_toks=cont_toks,
...
@@ -1327,11 +1324,11 @@ class HFLM(TemplateLM):
         return re_ord.get_original(res)

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []

-        def _collate(req: Tuple[str, dict]):
+        def _collate(req: tuple[str, dict]):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
...
@@ -1394,7 +1391,7 @@ class HFLM(TemplateLM):
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            if "max_gen_toks" in kwargs.keys():
+            if "max_gen_toks" in kwargs:
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
                 max_gen_toks = self.max_gen_toks
...
@@ -1472,7 +1469,7 @@ class HFLM(TemplateLM):
         return res

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
...
lm_eval/models/openai_completions.py

+from __future__ import annotations
+
 import logging
 import os
 from functools import cached_property
 from operator import itemgetter
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any

 from lm_eval.api.registry import register_model
 from lm_eval.models.api_models import TemplateAPI
...
@@ -26,9 +28,9 @@ class LocalCompletionsAPI(TemplateAPI):
     def _create_payload(
         self,
-        messages: Union[List[List[int]], List[dict], List[str], str],
+        messages: list[list[int]] | list[dict] | list[str] | str,
         generate=False,
-        gen_kwargs: Optional[dict] = None,
+        gen_kwargs: dict | None = None,
         seed: int = 1234,
         eos=None,
         **kwargs,
...
@@ -50,7 +52,6 @@ class LocalCompletionsAPI(TemplateAPI):
                 "seed": seed,
                 **gen_kwargs,
             }
-        else:
             return {
                 "model": self.model,
                 "prompt": messages,
...
@@ -63,11 +64,11 @@ class LocalCompletionsAPI(TemplateAPI):
     @staticmethod
     def parse_logprobs(
-        outputs: Union[Dict, List[Dict]],
-        tokens: List[List[int]] = None,
-        ctxlens: List[int] = None,
+        outputs: dict | list[dict],
+        tokens: list[list[int]] = None,
+        ctxlens: list[int] = None,
         **kwargs,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
...
@@ -88,7 +89,7 @@ class LocalCompletionsAPI(TemplateAPI):
         return res

     @staticmethod
-    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
+    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
...
@@ -130,9 +131,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
     def _create_payload(
         self,
-        messages: List[Dict],
+        messages: list[dict],
         generate=False,
-        gen_kwargs: dict = None,
+        gen_kwargs: dict | None = None,
         seed=1234,
         eos=None,
         **kwargs,
...
@@ -160,7 +161,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
         }

     @staticmethod
-    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
+    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
...
@@ -173,11 +174,11 @@ class LocalChatCompletion(LocalCompletionsAPI):
     def tok_encode(
         self,
-        string: Union[str, Any],
+        string: str | Any,
         left_truncate_len=None,
         add_special_tokens=None,
         **kwargs,
-    ) -> Union[List[str], List[int], Any]:
+    ) -> list[str] | list[int] | Any:
         return string

     def loglikelihood(self, requests, **kwargs):
...
@@ -219,7 +220,7 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
         )
         return super().loglikelihood(requests, **kwargs)

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         return ""
...
@@ -261,7 +262,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
     def _create_payload(
         self,
-        messages: List[Dict],
+        messages: list[dict],
         generate=False,
         gen_kwargs: dict = None,
         seed=1234,
...
lm_eval/models/vllm_causallms.py
View file @
1f97a945
from
__future__
import
annotations
import
copy
import
copy
import
gc
import
gc
import
inspect
import
inspect
...
@@ -8,7 +10,7 @@ from importlib.util import find_spec
...
@@ -8,7 +10,7 @@ from importlib.util import find_spec
from
multiprocessing
import
Process
,
Queue
from
multiprocessing
import
Process
,
Queue
from
queue
import
Empty
from
queue
import
Empty
from
time
import
sleep
from
time
import
sleep
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Literal
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Literal
import
jinja2
import
jinja2
from
more_itertools
import
distribute
from
more_itertools
import
distribute
...
@@ -51,10 +53,10 @@ eval_logger = logging.getLogger(__name__)
...
@@ -51,10 +53,10 @@ eval_logger = logging.getLogger(__name__)
def
_vllm_mp_worker
(
def
_vllm_mp_worker
(
model_args
:
dict
,
model_args
:
dict
,
sampling_params
:
"
SamplingParams
"
,
sampling_params
:
SamplingParams
,
requests
:
list
[
list
[
int
]],
requests
:
list
[
list
[
int
]],
lora_request
:
"
LoRARequest
"
,
lora_request
:
LoRARequest
,
result_queue
:
"
Queue
"
,
result_queue
:
Queue
,
dp_size
:
int
,
dp_size
:
int
,
local_dp_rank
:
int
,
local_dp_rank
:
int
,
dp_master_port
:
int
,
dp_master_port
:
int
,
...
@@ -114,30 +116,30 @@ class VLLM(TemplateLM):
...
@@ -114,30 +116,30 @@ class VLLM(TemplateLM):
self
,
self
,
pretrained
:
str
,
pretrained
:
str
,
dtype
:
Literal
[
"float16"
,
"bfloat16"
,
"float32"
,
"auto"
]
=
"auto"
,
dtype
:
Literal
[
"float16"
,
"bfloat16"
,
"float32"
,
"auto"
]
=
"auto"
,
revision
:
Optional
[
str
]
=
None
,
revision
:
str
|
None
=
None
,
trust_remote_code
:
Optional
[
bool
]
=
False
,
trust_remote_code
:
bool
|
None
=
False
,
tokenizer
:
Optional
[
str
]
=
None
,
tokenizer
:
str
|
None
=
None
,
tokenizer_mode
:
Literal
[
"auto"
,
"slow"
]
=
"auto"
,
tokenizer_mode
:
Literal
[
"auto"
,
"slow"
]
=
"auto"
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
tokenizer_revision
:
str
|
None
=
None
,
add_bos_token
:
Optional
[
bool
]
=
False
,
add_bos_token
:
bool
|
None
=
False
,
prefix_token_id
:
Optional
[
int
]
=
None
,
prefix_token_id
:
int
|
None
=
None
,
tensor_parallel_size
:
int
=
1
,
tensor_parallel_size
:
int
=
1
,
quantization
:
Optional
[
str
]
=
None
,
quantization
:
str
|
None
=
None
,
max_gen_toks
:
int
=
256
,
max_gen_toks
:
int
=
256
,
swap_space
:
int
=
4
,
swap_space
:
int
=
4
,
batch_size
:
Union
[
str
,
int
]
=
1
,
batch_size
:
str
|
int
=
1
,
max_batch_size
=
None
,
max_batch_size
:
int
|
None
=
None
,
max_length
:
int
=
None
,
max_length
:
int
|
None
=
None
,
max_model_len
:
int
=
None
,
max_model_len
:
int
|
None
=
None
,
seed
:
int
=
1234
,
seed
:
int
=
1234
,
gpu_memory_utilization
:
float
=
0.9
,
gpu_memory_utilization
:
float
=
0.9
,
device
:
str
=
"cuda"
,
device
:
str
=
"cuda"
,
data_parallel_size
:
int
=
1
,
data_parallel_size
:
int
=
1
,
lora_local_path
:
str
=
None
,
lora_local_path
:
str
|
None
=
None
,
# VLLM: enable thinking tags in the prompt.
# VLLM: enable thinking tags in the prompt.
enable_thinking
:
bool
=
True
,
enable_thinking
:
bool
=
True
,
# End marker for thinking tags - splits to get response after this token (if provided).
# End marker for thinking tags - splits to get response after this token (if provided).
think_end_token
:
Optional
[
str
]
=
None
,
think_end_token
:
str
|
None
=
None
,
max_lora_rank
:
int
=
16
,
max_lora_rank
:
int
=
16
,
**
kwargs
,
**
kwargs
,
):
):
...
@@ -173,7 +175,7 @@ class VLLM(TemplateLM):
...
@@ -173,7 +175,7 @@ class VLLM(TemplateLM):
"quantization"
:
quantization
,
"quantization"
:
quantization
,
"seed"
:
int
(
seed
),
"seed"
:
int
(
seed
),
"device"
:
str
(
device
),
"device"
:
str
(
device
),
"enable_lora"
:
True
if
lora_local_path
else
False
,
"enable_lora"
:
bool
(
lora_local_path
)
,
"max_lora_rank"
:
int
(
max_lora_rank
),
"max_lora_rank"
:
int
(
max_lora_rank
),
}
}
self
.
model_args
.
update
(
kwargs
)
self
.
model_args
.
update
(
kwargs
)
...
@@ -304,7 +306,7 @@ class VLLM(TemplateLM):
         return self._max_gen_toks
     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
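List[Dict[str, str]] becomes the builtin-generic spelling list[dict[str, str]] (PEP 585). A small illustrative sketch of a chat-history payload with that shape; render() is hypothetical and not the harness's template logic:

# Illustrative: builtin generics (PEP 585) for a chat-history payload shaped
# like the argument above.
from __future__ import annotations


def render(chat_history: list[dict[str, str]]) -> str:
    return "\n".join(f"{turn['role']}: {turn['content']}" for turn in chat_history)


history = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
]
print(render(history))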
...
@@ -339,14 +341,14 @@ class VLLM(TemplateLM):
     def tok_encode(
         self,
-        string: Union[str, List[str]],
+        string: str | list[str],
         left_truncate_len: int = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
-    ) -> Union[List[int], List[List[int]]]:
+    ) -> list[int] | list[list[int]]:
         if not add_special_tokens:
             add_special_tokens = False or self.add_bos_token
-        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+        encoding: list[list[int]] | list[int] = self.tokenizer(
             string,
             add_special_tokens=add_special_tokens,
             truncation=truncation,
...
@@ -364,10 +366,10 @@ class VLLM(TemplateLM):
     def _model_generate(
         self,
-        requests: List[List[int]] = None,
+        requests: list[list[int]] = None,
         generate: bool = False,
         max_tokens: int = None,
-        stop: Optional[List[str]] = None,
+        stop: list[str] | None = None,
         **kwargs,
     ):
         if generate:
...
@@ -385,7 +387,7 @@ class VLLM(TemplateLM):
        def run_inference_one_model(
            model_args: dict,
            sampling_params: SamplingParams,
-            requests: List[List[int]],
+            requests: list[list[int]],
            lora_request: LoRARequest,
        ):
            llm = LLM(**model_args)
...
@@ -454,7 +456,7 @@ class VLLM(TemplateLM):
                if dead_procs:
                    raise RuntimeError(
                        f"Worker processes {dead_procs} died unexpectedly"
-                    )
+                    ) from None
                continue
            results = [rank_res[i] for i in range(len(procs))]
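Appending `from None` to the re-raised RuntimeError suppresses implicit exception chaining, so the traceback reports only the worker failure rather than also printing the exception that was being handled. An illustrative sketch; the poll helper and the OSError are stand-ins, not the project's code:

# Illustrative stand-in: "from None" hides the exception that was being
# handled from the resulting traceback.
def poll(dead_procs):
    try:
        raise OSError("pipe closed")  # stand-in for a low-level worker failure
    except OSError:
        if dead_procs:
            raise RuntimeError(
                f"Worker processes {dead_procs} died unexpectedly"
            ) from None


try:
    poll([1, 3])
except RuntimeError as err:
    print(err)                       # Worker processes [1, 3] died unexpectedly
    print(err.__suppress_context__)  # True -> the OSError is omitted from the traceback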
...
@@ -481,14 +483,14 @@ class VLLM(TemplateLM):
         outputs = self.model.generate(
             prompt_token_ids=requests,
             sampling_params=sampling_params,
-            use_tqdm=True if self.batch_size == "auto" else False,
+            use_tqdm=self.batch_size == "auto",
             lora_request=self.lora_request,
         )
         return outputs
     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+    ) -> list[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             adaptive_batch_size = len(requests)
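use_tqdm is simplified because `self.batch_size == "auto"` already evaluates to a bool, so the surrounding ternary added nothing. A trivial illustrative check:

# Illustrative: the comparison already yields a bool, so the ternary is redundant.
for batch_size in ("auto", 8):
    assert (True if batch_size == "auto" else False) == (batch_size == "auto")
    print(batch_size, "->", batch_size == "auto")  # auto -> True, 8 -> False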
...
@@ -503,7 +505,7 @@ class VLLM(TemplateLM):
                 disable=(disable_tqdm or (self.rank != 0)),
             )
         ):
-            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                 map(
                     make_disjoint_window,
                     get_rolling_token_windows(
...
@@ -556,13 +558,13 @@ class VLLM(TemplateLM):
         return loglikelihoods
     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+    ) -> list[str]:
         res = []
         # batch tokenize contexts
         context, all_gen_kwargs = zip(*(req.args for req in requests))
-        context_encoding: List[List[int]] = self.tok_encode(
+        context_encoding: list[list[int]] = self.tok_encode(
             context, add_special_tokens=self.add_bos_token
         )
         requests = [
...
@@ -608,7 +610,7 @@ class VLLM(TemplateLM):
                     raise ValueError(
                         f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                     )
-                if "max_gen_toks" in kwargs.keys():
+                if "max_gen_toks" in kwargs:
                     max_gen_toks = kwargs.pop("max_gen_toks")
                 else:
                     max_gen_toks = self.max_gen_toks
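Membership tests on a dict already look at its keys, so `"max_gen_toks" in kwargs` is equivalent to the `.keys()` form it replaces. A short illustrative check with a hypothetical kwargs dict:

# Illustrative: "in" on a dict tests keys directly, so .keys() is unnecessary.
kwargs = {"max_gen_toks": 128, "temperature": 0.0}  # hypothetical generation kwargs
assert ("max_gen_toks" in kwargs) == ("max_gen_toks" in kwargs.keys())
max_gen_toks = kwargs.pop("max_gen_toks") if "max_gen_toks" in kwargs else 256
print(max_gen_toks)  # 128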
...
@@ -634,7 +636,7 @@ class VLLM(TemplateLM):
             )
             # cache generations
-            for output, context in zip(cont, context):
+            for output, context_ in zip(cont, context):
                 generated_text: str = output.outputs[0].text
                 # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                 generated_text = postprocess_generated_text(
...
@@ -642,7 +644,7 @@ class VLLM(TemplateLM):
                 )
                 res.append(generated_text)
                 self.cache_hook.add_partial(
-                    "generate_until", (context, gen_kwargs), generated_text
+                    "generate_until", (context_, gen_kwargs), generated_text
                 )
                 pbar.update(1)
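The loop variable is renamed from context to context_ (and the cache call above updated to match) so it no longer rebinds the context sequence being zipped over. A small illustrative sketch of the hazard the rename avoids, with made-up data:

# Illustrative sketch of the shadowing hazard: reusing the name "context" as
# the loop variable would rebind the sequence being iterated.
cont = ["out-a", "out-b"]
context = ["prompt-a", "prompt-b"]

for output, context_ in zip(cont, context):
    # "context" still refers to the full list of prompts here; "context_" is
    # the single prompt paired with this output.
    print(output, context_, len(context))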
...
@@ -652,9 +654,9 @@ class VLLM(TemplateLM):
     def _loglikelihood_tokens(
         self,
-        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        requests: list[tuple[tuple[str, str], list[int], list[int]]],
         disable_tqdm: bool = False,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []
         def _collate(x):
...
@@ -675,7 +677,7 @@ class VLLM(TemplateLM):
         for chunk in chunks:
             inputs = []
             ctxlens = []
-            for cache_key, context_enc, continuation_enc in chunk:
+            for _cache_key, context_enc, continuation_enc in chunk:
                 if (
                     full_length := len(context_enc + continuation_enc)
                 ) > self.max_length:
...
@@ -713,7 +715,7 @@ class VLLM(TemplateLM):
         return re_ord.get_original(res)
     @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
+    def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
         """Process logprobs and tokens.
         :param tokens: list
...