gaoqiong / lm-evaluation-harness · Commits · 0230356c

Commit 0230356c (unverified), parent 9169899b
Authored Nov 30, 2024 by Baber Abbasi; committed by GitHub on Nov 30, 2024

make utility function to handle `until` (#2518)

* make utility function to handle `until`
* fix text
Changes: 9 changed files with 82 additions and 74 deletions (+82, -74)

  lm_eval/models/anthropic_llms.py      +8   -3
  lm_eval/models/api_models.py          +19  -0
  lm_eval/models/hf_vlms.py             +4   -16
  lm_eval/models/huggingface.py         +4   -15
  lm_eval/models/openai_completions.py  +7   -3
  lm_eval/models/utils.py               +18  -0
  lm_eval/models/vllm_causallms.py      +10  -17
  lm_eval/models/vllm_vlms.py           +9   -17
  tests/models/test_api.py              +3   -3
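For orientation, a minimal sketch (not taken verbatim from any single file; the values are illustrative) of the pattern the diffs below consolidate: each backend used to normalize `until` and append its EOS string inline, and now delegates to the shared helper added in lm_eval/models/utils.py.

    # Requires a version of lm-evaluation-harness that includes this commit.
    from lm_eval.models.utils import handle_stop_sequences

    kwargs = {"until": "###", "max_gen_toks": 128}  # illustrative gen_kwargs
    eos = "<|endoftext|>"

    # Old inline pattern (roughly): coerce str -> list, validate, append EOS.
    until = kwargs.get("until")
    if isinstance(until, str):
        until = [until]
    until = (until or []) + [eos]

    # New pattern: one call handles str / list / None and skips a duplicate EOS.
    assert handle_stop_sequences(kwargs.pop("until", None), eos=eos) == until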
lm_eval/models/anthropic_llms.py

@@ -8,7 +8,7 @@ from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.models.openai_completions import LocalCompletionsAPI
-from lm_eval.models.utils import retry_on_specific_exceptions
+from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions


 eval_logger = utils.eval_logger

@@ -311,7 +311,12 @@ class AnthropicChat(LocalCompletionsAPI):
         }

     def _create_payload(
-        self, messages: List[Dict], generate=True, gen_kwargs: dict = None, **kwargs
+        self,
+        messages: List[Dict],
+        generate=True,
+        gen_kwargs: dict = None,
+        eos="\n\nHuman:",
+        **kwargs,
     ) -> dict:
         system = (
             messages[0].get("content") if messages[0].get("role") == "system" else None

@@ -321,7 +326,7 @@ class AnthropicChat(LocalCompletionsAPI):
         gen_kwargs.pop("do_sample", False)
         max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
         temperature = gen_kwargs.pop("temperature", 0)
-        stop = gen_kwargs.pop("until", ["\n\nHuman:"])
+        stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
         if not isinstance(stop, list):
             stop = [stop]
         out = {
lm_eval/models/api_models.py

@@ -80,6 +80,7 @@ class TemplateAPI(TemplateLM):
         revision: Optional[str] = "main",
         use_fast_tokenizer: bool = True,
         verify_certificate: bool = True,
+        eos_string: str = None,
         **kwargs,
     ) -> None:
         super().__init__()

@@ -124,6 +125,7 @@ class TemplateAPI(TemplateLM):
         self.tokenized_requests = tokenized_requests
         self.max_retries = int(max_retries)
         self.verify_certificate = verify_certificate
+        self._eos_string = eos_string
         eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
         if self.tokenizer_backend is None:

@@ -176,6 +178,7 @@ class TemplateAPI(TemplateLM):
         generate: bool = True,
         gen_kwargs: Optional[dict] = None,
         seed: int = 1234,
+        eos: str = None,
         **kwargs,
     ) -> dict:
         """This method is responsible for creating the json payload that will be sent to the API."""

@@ -268,6 +271,21 @@ class TemplateAPI(TemplateLM):
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.eot_token

+    @cached_property
+    def eos_string(self) -> Optional[str]:
+        if self._eos_string:
+            return self._eos_string
+        elif self.tokenizer is not None:
+            if self.tokenizer_backend == "huggingface":
+                return self.tokenizer.eos_token
+            elif self.tokenizer_backend == "tiktoken":
+                return self.tokenizer.decode([self.tokenizer.eot_token])
+        else:
+            eval_logger.warning(
+                "Cannot determine EOS string to pass to stop sequence. Manually set by passing `eos_string` to model_args."
+            )
+            return None
+
     @cached_property
     def prefix_token_id(self) -> Optional[int]:
         if self.tokenizer is None:

@@ -343,6 +361,7 @@ class TemplateAPI(TemplateLM):
                     generate=generate,
                     gen_kwargs=gen_kwargs,
                     seed=self._seed,
+                    eos=self.eos_string,
                     **kwargs,
                 ),
                 headers=self.header,
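The new `eos_string` property above gives API-backed models a concrete EOS string to hand to the stop-sequence helper. A standalone sketch of the resolution order it implements; `resolve_eos_string` below is illustrative, not part of the library:

    from typing import Optional

    def resolve_eos_string(
        explicit: Optional[str],
        tokenizer,                       # HF tokenizer, tiktoken encoding, or None
        tokenizer_backend: Optional[str],
    ) -> Optional[str]:
        """Illustrative mirror of TemplateAPI.eos_string's resolution order."""
        if explicit:                     # 1. a user-supplied eos_string wins
            return explicit
        if tokenizer is not None:        # 2. fall back to the tokenizer's EOS
            if tokenizer_backend == "huggingface":
                return tokenizer.eos_token
            if tokenizer_backend == "tiktoken":
                return tokenizer.decode([tokenizer.eot_token])
        return None                      # 3. unknown; the real property logs a warning

    # e.g. resolve_eos_string("<|eot_id|>", None, None) -> "<|eot_id|>"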
lm_eval/models/hf_vlms.py

@@ -14,6 +14,7 @@ from lm_eval.models.huggingface import HFLM
 from lm_eval.models.utils import (
     Collator,
     flatten_image_list,
+    handle_stop_sequences,
     pad_and_concat,
     replace_placeholders,
     stop_sequences_criteria,

@@ -629,7 +630,7 @@ class HFMultimodalLM(HFLM):
         chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
         ### Up to here: was identical to non-multimodal HFLM generate_until ###
-
+        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
         for chunk in chunks:
             contexts, all_gen_kwargs, aux_arguments = zip(*chunk)

@@ -646,27 +647,14 @@ class HFMultimodalLM(HFLM):
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
lm_eval/models/huggingface.py

@@ -33,6 +33,7 @@ from lm_eval.models.utils import (
     clear_torch_cache,
     configure_pad_token,
     get_dtype,
+    handle_stop_sequences,
     pad_and_concat,
     stop_sequences_criteria,
 )

@@ -1255,33 +1256,21 @@ class HFLM(TemplateLM):
             group_fn=lambda x: x[1],
         )
         chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
+        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
         for chunk in chunks:
             contexts, all_gen_kwargs = zip(*chunk)
             # we assume all gen kwargs in the batch are the same
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
lm_eval/models/openai_completions.py

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from lm_eval.api.registry import register_model
 from lm_eval.models.api_models import TemplateAPI
+from lm_eval.models.utils import handle_stop_sequences
 from lm_eval.utils import eval_logger

@@ -25,6 +26,7 @@ class LocalCompletionsAPI(TemplateAPI):
         generate=False,
         gen_kwargs: Optional[dict] = None,
         seed: int = 1234,
+        eos=None,
         **kwargs,
     ) -> dict:
         if generate:

@@ -34,7 +36,7 @@ class LocalCompletionsAPI(TemplateAPI):
             else:
                 max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
             temperature = gen_kwargs.pop("temperature", 0)
-            stop = gen_kwargs.pop("until", ["<|endoftext|>"])
+            stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
             return {
                 "prompt": messages,
                 "model": self.model,

@@ -124,6 +126,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
         generate=False,
         gen_kwargs: dict = None,
         seed=1234,
+        eos=None,
         **kwargs,
     ) -> dict:
         assert (

@@ -135,7 +138,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
         else:
             max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
         temperature = gen_kwargs.pop("temperature", 0)
-        stop = gen_kwargs.pop("until", ["<|endoftext|>"])
+        stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
         if not isinstance(stop, (list, tuple)):
             stop = [stop]
         return {

@@ -252,6 +255,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
         generate=False,
         gen_kwargs: dict = None,
         seed=1234,
+        eos="<|endoftext|>",
         **kwargs,
     ) -> dict:
         assert (

@@ -263,7 +267,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
         else:
             max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
         temperature = gen_kwargs.pop("temperature", 0)
-        stop = gen_kwargs.pop("until", ["<|endoftext|>"])
+        stop = handle_stop_sequences(gen_kwargs.pop("until", ["<|endoftext|>"]), eos)
         if not isinstance(stop, (list, tuple)):
             stop = [stop]
         output = {
lm_eval/models/utils.py

@@ -709,3 +709,21 @@ def flatten_image_list(images: List[List]):
     :return: a list of PIL images, via concatenating all the sub-lists in order.
     """
     return [image for image_list in images for image in image_list]
+
+
+def handle_stop_sequences(
+    until: Union[str, List[str], None], eos: Optional[str]
+) -> List[str]:
+    """Ensures that the `until` parameter is a list of stop sequences and includes the EOS token."""
+    if isinstance(until, str):
+        until = [until]
+    elif until is None:
+        until = []
+    elif not isinstance(until, list):
+        raise ValueError(
+            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+        )
+    if eos is not None and eos not in until:
+        until.append(eos)
+    return until
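The helper's behavior for the accepted `until` shapes, shown as a small usage sketch (assumes a lm-evaluation-harness version that includes this commit):

    from lm_eval.models.utils import handle_stop_sequences

    eos = "<|endoftext|>"

    handle_stop_sequences("\n\n", eos=eos)        # str  -> ['\n\n', '<|endoftext|>']
    handle_stop_sequences(["###", eos], eos=eos)  # list -> EOS appended only if missing
    handle_stop_sequences(None, eos=eos)          # None -> ['<|endoftext|>']
    handle_stop_sequences(None, eos=None)         # no EOS known -> []
    # any other type for `until` raises ValueError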
lm_eval/models/vllm_causallms.py

@@ -10,7 +10,12 @@ from tqdm import tqdm
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
-from lm_eval.models.utils import Collator, configure_pad_token, undistribute
+from lm_eval.models.utils import (
+    Collator,
+    configure_pad_token,
+    handle_stop_sequences,
+    undistribute,
+)
 from lm_eval.utils import (
     eval_logger,
     get_rolling_token_windows,

@@ -346,6 +351,7 @@ class VLLM(TemplateLM):
             desc="Running generate_until requests",
         )
         # for each different set of kwargs, we execute all requests, by batch.
+        eos = self.tokenizer.decode(self.eot_token_id)
         for chunk in chunks:
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)

@@ -353,27 +359,14 @@ class VLLM(TemplateLM):
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
-                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
+                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tokenizer.decode(self.eot_token_id)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
lm_eval/models/vllm_vlms.py

@@ -7,7 +7,12 @@ from tqdm import tqdm
 from lm_eval.api.instance import Instance
 from lm_eval.api.registry import register_model
-from lm_eval.models.utils import Collator, replace_placeholders, undistribute
+from lm_eval.models.utils import (
+    Collator,
+    handle_stop_sequences,
+    replace_placeholders,
+    undistribute,
+)
 from lm_eval.models.vllm_causallms import VLLM
 from lm_eval.utils import eval_logger

@@ -225,7 +230,7 @@ class VLLM_VLM(VLLM):
             group_fn=lambda x: x[1],
         )
         chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
-
+        eos = self.tokenizer.decode(self.eot_token_id)
         for chunk in chunks:
             contexts, all_gen_kwargs, aux_arguments = zip(*chunk)

@@ -241,27 +246,14 @@ class VLLM_VLM(VLLM):
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tokenizer.decode(self.eot_token_id)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
tests/models/test_api.py

@@ -63,13 +63,13 @@ def test_create_payload_loglikelihood(api):
     (
         ["Hello, how are"],
         True,
-        {"max_gen_toks": 100, "temperature": 0.7},
+        {"max_gen_toks": 100, "temperature": 0.7, "until": ["hi"]},
         {
             "prompt": "Hello, how are",
             "model": "gpt-3.5-turbo",
             "max_tokens": 100,
             "temperature": 0.7,
-            "stop": ["<|endoftext|>"],
+            "stop": ["hi"],
             "seed": 1234,
         },
     ),

@@ -82,7 +82,7 @@ def test_create_payload_loglikelihood(api):
             "model": "gpt-3.5-turbo",
             "max_tokens": 256,
             "temperature": 0,
-            "stop": ["<|endoftext|>"],
+            "stop": [],
             "seed": 1234,
         },
     ),