gaoqiong / lm-evaluation-harness · Commits

Commit 03f0e80e
Authored Oct 15, 2025 by Baber
Parent: d701d50f

    fix bos token handling

Showing 3 changed files with 66 additions and 60 deletions:

- lm_eval/api/model.py (+7, -8)
- lm_eval/models/huggingface.py (+1, -1)
- lm_eval/models/vllm_causallms.py (+58, -51)
lm_eval/api/model.py (+7, -8)
```diff
@@ -324,6 +324,7 @@ class TemplateLM(LM):
     """

     tokenizer = None
+    backend = "causal"

     @property
     @abc.abstractmethod
```
```diff
@@ -378,24 +379,22 @@ class TemplateLM(LM):
         handle empty context (see loglikelihood method).
         """
         assert context, "Context cannot be empty!"
-        import transformers

         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
             context = context[:-n_spaces]

-        model_class = getattr(self, "AUTO_MODEL_CLASS", None)
-
-        if model_class == transformers.AutoModelForSeq2SeqLM:
-            context_enc = self.tok_encode(context)
-            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
-        else:
+        if self.backend == "causal":
             whole_enc = self.tok_encode(context + continuation)
             context_enc = self.tok_encode(context)
             context_enc_len = len(context_enc)
             continuation_enc = whole_enc[context_enc_len:]
+        else:
+            # for SEQ2SEQ case we need to encode separately
+            context_enc = self.tok_encode(context)
+            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)

         return context_enc, continuation_enc
```
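Why the causal branch tokenizes the concatenation and slices, instead of encoding the two pieces separately: BPE merges can cross the context/continuation boundary. A minimal sketch (illustrative only, not from the commit; `gpt2` is just a convenient BPE tokenizer):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any BPE tokenizer shows the effect

context, continuation = "Hello wor", "ld"
separate = tok.encode(context) + tok.encode(continuation)
joint = tok.encode(context + continuation)

# "wor" + "ld" tokenized apart need not equal "world" tokenized whole,
# so scoring continuation tokens requires slicing the joint encoding.
print(separate, joint, separate == joint)
```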
```diff
@@ -433,7 +432,7 @@ class TemplateLM(LM):
             continuation_enc = self.tok_encode(
                 continuation, add_special_tokens=False
             )
-            # BOS or EOS as context
+            # BOS or EOS as context: handle when context is empty -> (context + continuation) -> (BOS + continuation)
             context_enc, continuation_enc = (
                 ([self.prefix_token_id], continuation_enc)
                 if self.prefix_token_id != continuation_enc[0]
```
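For orientation, the ternary above substitutes `prefix_token_id` (typically BOS or EOS) when the context is empty, and the `!=` guard avoids doubling the prefix if the tokenizer already prepended it to the continuation. A toy restatement with made-up token ids (the elided else-branch is paraphrased, not quoted from the file):

```python
prefix_token_id = 1              # stand-in for a model's BOS id
continuation_enc = [5024, 318]   # hypothetical continuation tokens

if prefix_token_id != continuation_enc[0]:
    # empty context -> score P(continuation | BOS) instead of P(continuation | "")
    context_enc = [prefix_token_id]
else:
    # tokenizer already prepended the prefix; treat it as the context
    context_enc, continuation_enc = continuation_enc[:1], continuation_enc[1:]

print(context_enc, continuation_enc)  # [1] [5024, 318]
```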
lm_eval/models/huggingface.py (+1, -1)
```diff
@@ -258,7 +258,7 @@ class HFLM(TemplateLM):
             else {}
         )
-        self.add_bos_token = add_bos_token if add_bos_token is not None else None
+        self.add_bos_token = add_bos_token
         self._max_length = max_length
         self.pretrained = pretrained
```
lm_eval/models/vllm_causallms.py (+58, -51)
```diff
+from __future__ import annotations
+
 import copy
 import gc
 import logging
```

```diff
@@ -7,7 +9,7 @@ from importlib.util import find_spec
 from multiprocessing import Process, Queue
 from queue import Empty
 from time import sleep
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal

 import jinja2
 from more_itertools import distribute
```
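The new `from __future__ import annotations` is what makes the rest of this diff possible on older interpreters: annotations are stored as strings rather than evaluated at import time (PEP 563), so `str | None` and `list[int]` hints work before Python 3.10, and the quoted forward references below (`"SamplingParams"`, `"LoRARequest"`) can lose their quotes. A quick standalone illustration (not from the repo):

```python
from __future__ import annotations


def first_token(ids: list[int] | None) -> int | None:
    # With postponed evaluation these hints are never executed at runtime,
    # so this module imports cleanly even where `int | None` would raise.
    return ids[0] if ids else None


print(first_token([101, 102]))  # 101
print(first_token(None))        # None
```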
```diff
@@ -50,10 +52,10 @@ eval_logger = logging.getLogger(__name__)
 def _vllm_mp_worker(
     model_args: dict,
-    sampling_params: list["SamplingParams"],
+    sampling_params: list[SamplingParams],
     requests: list[list[int]],
-    lora_request: "LoRARequest",
-    result_queue: "Queue",
+    lora_request: LoRARequest,
+    result_queue: Queue,
     dp_size: int,
     local_dp_rank: int,
     dp_master_port: int,
```
```diff
@@ -113,18 +115,18 @@ class VLLM(TemplateLM):
         self,
         pretrained: str,
         dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
-        revision: Optional[str] = None,
-        trust_remote_code: Optional[bool] = False,
-        tokenizer: Optional[str] = None,
+        revision: str | None = None,
+        trust_remote_code: bool | None = False,
+        tokenizer: str | None = None,
         tokenizer_mode: Literal["auto", "slow"] = "auto",
-        tokenizer_revision: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
-        prefix_token_id: Optional[int] = None,
+        tokenizer_revision: str | None = None,
+        add_bos_token: bool | None = False,
+        prefix_token_id: int | None = None,
         tensor_parallel_size: int = 1,
-        quantization: Optional[str] = None,
+        quantization: str | None = None,
         max_gen_toks: int = 256,
         swap_space: int = 4,
-        batch_size: Union[str, int] = 1,
+        batch_size: str | int = 1,
         max_batch_size=None,
         max_length: int = None,
         max_model_len: int = None,
```

```diff
@@ -134,9 +136,9 @@ class VLLM(TemplateLM):
         lora_local_path: str = None,
         # VLLM: enable thinking tags in the prompt.
         enable_thinking: bool = True,
-        chat_template_args: Optional[dict] = None,
+        chat_template_args: dict | None = None,
         # End marker for thinking tags - splits to get response after this token (if provided).
-        think_end_token: Optional[str] = None,
+        think_end_token: str | None = None,
         max_lora_rank: int = 16,
         **kwargs,
     ):
```
```diff
@@ -195,11 +197,7 @@ class VLLM(TemplateLM):
             self.batch_size = "auto"
             eval_logger.info("Manual batching is not compatible with data parallelism.")

-        if "gemma" in pretrained.lower():
-            add_bos_token = True
-            eval_logger.info(
-                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
-            )
+        self.add_bos_token = add_bos_token

         from transformers import AutoConfig
```
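Context for the removal: Gemma's tokenizer already asks for a BOS token by default, so the hard-coded override is presumably redundant once the tokenizer's own configuration is respected (see the next hunk). A quick check, assuming access to a Gemma checkpoint (`google/gemma-2b` is gated and shown only as an example):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2b")  # example; requires HF access
print(tok.add_bos_token)                            # True by default
print(tok("hi").input_ids[0] == tok.bos_token_id)   # True: BOS prepended automatically
```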
```diff
@@ -211,14 +209,17 @@ class VLLM(TemplateLM):
             tokenizer_mode=tokenizer_mode,
             trust_remote_code=trust_remote_code,
             revision=tokenizer_revision,
-            add_bos_token=add_bos_token,
+            **(
+                {"add_bos_token": self.add_bos_token}
+                if self.add_bos_token is not None
+                else {}
+            ),
         )
         self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
         self.chat_template_args = chat_template_args or {}
         self.enable_thinking = self.chat_template_args.pop(
             "enable_thinking", enable_thinking
         )
-        self.add_bos_token = add_bos_token

         if parse_version(version("vllm")) >= parse_version("0.8.3"):
             kwargs_resolve_hf_chat_template = {
```
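The `**({...} if ... else {})` idiom forwards `add_bos_token` to the tokenizer only when it was explicitly configured; splatting an empty dict passes nothing, so the tokenizer's own default wins. A self-contained sketch of the pattern (names are illustrative):

```python
def tokenizer_kwargs(add_bos_token: bool | None) -> dict:
    # None means "unset": forward nothing and keep the tokenizer's default.
    return {"add_bos_token": add_bos_token} if add_bos_token is not None else {}


print(tokenizer_kwargs(None))   # {}
print(tokenizer_kwargs(False))  # {'add_bos_token': False} - explicit override
print(tokenizer_kwargs(True))   # {'add_bos_token': True}
```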
```diff
@@ -265,7 +266,7 @@ class VLLM(TemplateLM):
         self.lora_request = None

     @property
-    def eot_token_id(self):
+    def eot_token_id(self) -> int | None:
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
         return self.tokenizer.eos_token_id
```
```diff
@@ -300,7 +301,7 @@ class VLLM(TemplateLM):
         return self._max_gen_toks

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
```
```diff
@@ -337,18 +338,26 @@ class VLLM(TemplateLM):
     def tok_encode(
         self,
-        string: Union[str, List[str]],
-        left_truncate_len: int = None,
-        add_special_tokens: bool = False,
+        string: str | list[str],
+        left_truncate_len: int | None = None,
+        add_special_tokens: bool | None = None,
         truncation: bool = False,
-    ) -> Union[List[int], List[List[int]]]:
-        if not add_special_tokens:
-            add_special_tokens = False or self.add_bos_token
-        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+    ) -> list[int] | list[list[int]]:
+        add_special_kwargs = (
+            {"add_special_tokens": add_special_tokens or self.add_bos_token}
+            if (add_special_tokens is not None or self.add_bos_token is not None)
+            else {}
+        )
+        # handle chat template
+        if self.tokenizer.bos_token and (
+            string[0] if isinstance(string, list) else string
+        ).startswith(self.tokenizer.bos_token):
+            add_special_kwargs = {"add_special_tokens": False}
+        encoding: list[list[int]] | list[int] = self.tokenizer(
             string,
-            add_special_tokens=add_special_tokens,
             truncation=truncation,
             return_attention_mask=False,
+            **add_special_kwargs,
         ).input_ids
         # left-truncate the encoded context to be at most `left_truncate_len` tokens long
```
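The new `# handle chat template` guard exists because a rendered chat template often spells the BOS token out in the prompt string; tokenizing that with `add_special_tokens=True` would produce two BOS ids. A hedged demonstration (any tokenizer that prepends BOS will do; the checkpoint name is only an example and is gated on the HF Hub):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # example; gated
prompt = tok.bos_token + "Hello"  # template-style prompt with BOS already in the text

doubled = tok(prompt, add_special_tokens=True).input_ids   # BOS added again up front
single = tok(prompt, add_special_tokens=False).input_ids   # the literal BOS only

print(doubled[:2])  # two identical BOS ids back to back
print(single[:1])   # a single BOS id
```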
```diff
@@ -362,15 +371,15 @@ class VLLM(TemplateLM):
     def _model_generate(
         self,
-        requests: List[List[int]] = None,
+        requests: list[list[int]],
         generate: bool = False,
-        sampling_params: Union[List["SamplingParams"], "SamplingParams", None] = None,
+        sampling_params: list[SamplingParams] | SamplingParams | None = None,
     ):
         if not generate or sampling_params is None:
             sampling_params = SamplingParams(
                 temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
             )
-        if not isinstance(sampling_params, List):
+        if not isinstance(sampling_params, list):
             sampling_params = [sampling_params] * len(requests)
         if self.data_parallel_size > 1 and not self.V1:
             # vLLM hangs if resources are set in ray.remote
```
```diff
@@ -379,9 +388,9 @@ class VLLM(TemplateLM):
             @ray.remote
             def run_inference_one_model(
                 model_args: dict,
-                sampling_params: List["SamplingParams"],
-                requests: List[List[int]],
-                lora_request: "LoRARequest",
+                sampling_params: list[SamplingParams],
+                requests: list[list[int]],
+                lora_request: LoRARequest,
             ):
                 llm = LLM(**model_args)
                 return llm.generate(
```
```diff
@@ -487,8 +496,8 @@ class VLLM(TemplateLM):
         return outputs

     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             adaptive_batch_size = len(requests)
```
```diff
@@ -503,7 +512,7 @@ class VLLM(TemplateLM):
                 disable=(disable_tqdm or (self.rank != 0)),
             )
         ):
-            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                 map(
                     make_disjoint_window,
                     get_rolling_token_windows(
```
```diff
@@ -556,16 +565,14 @@ class VLLM(TemplateLM):
         return loglikelihoods

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []

         # batch tokenize contexts
         context, all_gen_kwargs = zip(*(req.args for req in requests))
-        context_encoding: List[List[int]] = self.tok_encode(
-            context, add_special_tokens=self.add_bos_token
-        )
-        requests = [
+        context_encoding = self.tok_encode(context)
+        reqs = [
             ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
         ]
```
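`generate_until` now leaves special-token policy entirely to `tok_encode` and batch-tokenizes the raw contexts; the rename to `reqs` also stops shadowing the `requests` parameter. A toy of the resulting packing (values invented):

```python
contexts = ("The capital of France is", "2 + 2 =")
encodings = [[464, 3139], [17, 1343]]  # stand-ins for self.tok_encode(contexts)
all_gen_kwargs = ({"until": ["\n"]}, {"until": ["\n"], "max_gen_toks": 8})

reqs = [((c, e), kw) for c, e, kw in zip(contexts, encodings, all_gen_kwargs)]
print(reqs[0])  # (('The capital of France is', [464, 3139]), {'until': ['\n']})
```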
```diff
@@ -579,7 +586,7 @@ class VLLM(TemplateLM):
             return -len(_requests[0][1]), _requests[0][0]

         re_ords = Collator(
-            requests,
+            reqs,
             _collate_gen,
             group_by=None,
         )
```
```diff
@@ -588,7 +595,7 @@ class VLLM(TemplateLM):
         )
         pbar = tqdm(
-            total=len(requests),
+            total=len(reqs),
             disable=(disable_tqdm or (self.rank != 0)),
             desc="Running generate_until requests",
         )
```
```diff
@@ -656,9 +663,9 @@ class VLLM(TemplateLM):
     def _loglikelihood_tokens(
         self,
-        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        requests: list[tuple[tuple[str, str], list[int], list[int]]],
         disable_tqdm: bool = False,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []

         def _collate(x):
```
```diff
@@ -717,7 +724,7 @@ class VLLM(TemplateLM):
         return re_ord.get_original(res)

     @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
+    def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
         """Process logprobs and tokens.

         :param tokens: list
```