gaoqiong / lm-evaluation-harness · Commits

Commit 0230356c (unverified)
Authored Nov 30, 2024 by Baber Abbasi, committed by GitHub on Nov 30, 2024
Parent: 9169899b

make utility function to handle `until` (#2518)

* make utility function to handle `until`
* fix text
Changes: 9 changed files, with 82 additions and 74 deletions (+82 -74)

  lm_eval/models/anthropic_llms.py       +8   -3
  lm_eval/models/api_models.py          +19   -0
  lm_eval/models/hf_vlms.py              +4  -16
  lm_eval/models/huggingface.py          +4  -15
  lm_eval/models/openai_completions.py   +7   -3
  lm_eval/models/utils.py               +18   -0
  lm_eval/models/vllm_causallms.py      +10  -17
  lm_eval/models/vllm_vlms.py            +9  -17
  tests/models/test_api.py               +3   -3
lm_eval/models/anthropic_llms.py

@@ -8,7 +8,7 @@ from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.models.openai_completions import LocalCompletionsAPI
-from lm_eval.models.utils import retry_on_specific_exceptions
+from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions

 eval_logger = utils.eval_logger

@@ -311,7 +311,12 @@ class AnthropicChat(LocalCompletionsAPI):
         }

     def _create_payload(
-        self, messages: List[Dict], generate=True, gen_kwargs: dict = None, **kwargs
+        self,
+        messages: List[Dict],
+        generate=True,
+        gen_kwargs: dict = None,
+        eos="\n\nHuman:",
+        **kwargs,
     ) -> dict:
         system = (
             messages[0].get("content") if messages[0].get("role") == "system" else None

@@ -321,7 +326,7 @@ class AnthropicChat(LocalCompletionsAPI):
         gen_kwargs.pop("do_sample", False)
         max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
         temperature = gen_kwargs.pop("temperature", 0)
-        stop = gen_kwargs.pop("until", ["\n\nHuman:"])
+        stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
         if not isinstance(stop, list):
             stop = [stop]
         out = {
lm_eval/models/api_models.py

@@ -80,6 +80,7 @@ class TemplateAPI(TemplateLM):
         revision: Optional[str] = "main",
         use_fast_tokenizer: bool = True,
         verify_certificate: bool = True,
+        eos_string: str = None,
         **kwargs,
     ) -> None:
         super().__init__()

@@ -124,6 +125,7 @@ class TemplateAPI(TemplateLM):
         self.tokenized_requests = tokenized_requests
         self.max_retries = int(max_retries)
         self.verify_certificate = verify_certificate
+        self._eos_string = eos_string
         eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
         if self.tokenizer_backend is None:

@@ -176,6 +178,7 @@ class TemplateAPI(TemplateLM):
         generate: bool = True,
         gen_kwargs: Optional[dict] = None,
         seed: int = 1234,
+        eos: str = None,
         **kwargs,
     ) -> dict:
         """This method is responsible for creating the json payload that will be sent to the API."""

@@ -268,6 +271,21 @@ class TemplateAPI(TemplateLM):
         elif self.tokenizer_backend == "tiktoken":
             return self.tokenizer.eot_token

+    @cached_property
+    def eos_string(self) -> Optional[str]:
+        if self._eos_string:
+            return self._eos_string
+        elif self.tokenizer is not None:
+            if self.tokenizer_backend == "huggingface":
+                return self.tokenizer.eos_token
+            elif self.tokenizer_backend == "tiktoken":
+                return self.tokenizer.decode([self.tokenizer.eot_token])
+        else:
+            eval_logger.warning(
+                "Cannot determine EOS string to pass to stop sequence. Manually set by passing `eos_string` to model_args."
+            )
+            return None
+
     @cached_property
     def prefix_token_id(self) -> Optional[int]:
         if self.tokenizer is None:

@@ -343,6 +361,7 @@ class TemplateAPI(TemplateLM):
                         generate=generate,
                         gen_kwargs=gen_kwargs,
                         seed=self._seed,
+                        eos=self.eos_string,
                         **kwargs,
                     ),
                     headers=self.header,
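As a reading aid for the `eos_string` property added above, the sketch below re-states its resolution order as a standalone function: an explicit `eos_string` model argument wins, otherwise the tokenizer's EOS is used, otherwise nothing can be determined and the real property logs a warning. This is an illustrative stand-in, not library code; the function name and the duck-typed `tokenizer` argument are made up for the example.

from typing import Optional


def resolve_eos_string(explicit_eos: Optional[str], tokenizer, tokenizer_backend: Optional[str]) -> Optional[str]:
    # illustrative stand-in for TemplateAPI.eos_string (see the cached_property in the diff above)
    if explicit_eos:
        # the user passed eos_string=... in model_args
        return explicit_eos
    if tokenizer is not None:
        if tokenizer_backend == "huggingface":
            return tokenizer.eos_token
        if tokenizer_backend == "tiktoken":
            # tiktoken exposes the EOT token id, so decode it back to a string
            return tokenizer.decode([tokenizer.eot_token])
    # no tokenizer and no explicit eos_string: the real property warns and returns None,
    # and handle_stop_sequences() then simply has no EOS stop sequence to append
    return None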
lm_eval/models/hf_vlms.py

@@ -14,6 +14,7 @@ from lm_eval.models.huggingface import HFLM
 from lm_eval.models.utils import (
     Collator,
     flatten_image_list,
+    handle_stop_sequences,
     pad_and_concat,
     replace_placeholders,
     stop_sequences_criteria,

@@ -629,7 +630,7 @@ class HFMultimodalLM(HFLM):
         chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
         ### Up to here: was identical to non-multimodal HFLM generate_until ###
-
+        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
         for chunk in chunks:
             contexts, all_gen_kwargs, aux_arguments = zip(*chunk)

@@ -646,27 +647,14 @@ class HFMultimodalLM(HFLM):
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
lm_eval/models/huggingface.py

@@ -33,6 +33,7 @@ from lm_eval.models.utils import (
     clear_torch_cache,
     configure_pad_token,
     get_dtype,
+    handle_stop_sequences,
     pad_and_concat,
     stop_sequences_criteria,
 )

@@ -1255,33 +1256,21 @@ class HFLM(TemplateLM):
             group_fn=lambda x: x[1],
         )
         chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
+        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
         for chunk in chunks:
             contexts, all_gen_kwargs = zip(*chunk)
             # we assume all gen kwargs in the batch are the same
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
lm_eval/models/openai_completions.py

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from lm_eval.api.registry import register_model
 from lm_eval.models.api_models import TemplateAPI
+from lm_eval.models.utils import handle_stop_sequences
 from lm_eval.utils import eval_logger

@@ -25,6 +26,7 @@ class LocalCompletionsAPI(TemplateAPI):
         generate=False,
         gen_kwargs: Optional[dict] = None,
         seed: int = 1234,
+        eos=None,
         **kwargs,
     ) -> dict:
         if generate:

@@ -34,7 +36,7 @@ class LocalCompletionsAPI(TemplateAPI):
         else:
             max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
             temperature = gen_kwargs.pop("temperature", 0)
-            stop = gen_kwargs.pop("until", ["<|endoftext|>"])
+            stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
             return {
                 "prompt": messages,
                 "model": self.model,

@@ -124,6 +126,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
         generate=False,
         gen_kwargs: dict = None,
         seed=1234,
+        eos=None,
         **kwargs,
     ) -> dict:
         assert (

@@ -135,7 +138,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
         else:
             max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
             temperature = gen_kwargs.pop("temperature", 0)
-            stop = gen_kwargs.pop("until", ["<|endoftext|>"])
+            stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
             if not isinstance(stop, (list, tuple)):
                 stop = [stop]
             return {

@@ -252,6 +255,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
         generate=False,
         gen_kwargs: dict = None,
         seed=1234,
+        eos="<|endoftext|>",
         **kwargs,
     ) -> dict:
         assert (

@@ -263,7 +267,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
         else:
             max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
             temperature = gen_kwargs.pop("temperature", 0)
-            stop = gen_kwargs.pop("until", ["<|endoftext|>"])
+            stop = handle_stop_sequences(gen_kwargs.pop("until", ["<|endoftext|>"]), eos)
             if not isinstance(stop, (list, tuple)):
                 stop = [stop]
             output = {
lm_eval/models/utils.py

@@ -709,3 +709,21 @@ def flatten_image_list(images: List[List]):
     :return: a list of PIL images, via concatenating all the sub-lists in order.
     """
     return [image for image_list in images for image in image_list]
+
+
+def handle_stop_sequences(
+    until: Union[str, List[str], None], eos: Optional[str]
+) -> List[str]:
+    """Ensures that the `until` parameter is a list of stop sequences and includes the EOS token."""
+    if isinstance(until, str):
+        until = [until]
+    elif until is None:
+        until = []
+    elif not isinstance(until, list):
+        raise ValueError(
+            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+        )
+    if eos is not None and eos not in until:
+        until.append(eos)
+    return until
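A minimal usage sketch of the helper added above (assuming a checkout of lm-evaluation-harness that includes this commit, so the import resolves); it covers the input shapes the call sites in this commit pass: a plain string, a list, and None.

# usage sketch for handle_stop_sequences as defined in the diff above
from lm_eval.models.utils import handle_stop_sequences

print(handle_stop_sequences("###", eos="<|endoftext|>"))        # ['###', '<|endoftext|>']
print(handle_stop_sequences(["\n\nHuman:"], eos="\n\nHuman:"))  # ['\n\nHuman:']  (eos already present, not duplicated)
print(handle_stop_sequences(None, eos="<|endoftext|>"))         # ['<|endoftext|>']
print(handle_stop_sequences(None, eos=None))                    # []

try:
    handle_stop_sequences({"until": "is a dict"}, eos=None)
except ValueError as err:
    print(err)  # anything other than str, list, or None is rejected

This is the same normalize-then-append-EOS behaviour that was previously duplicated inside each backend's generate_until loop; the callers below now just pass their backend-specific EOS string.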
lm_eval/models/vllm_causallms.py

@@ -10,7 +10,12 @@ from tqdm import tqdm
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
-from lm_eval.models.utils import Collator, configure_pad_token, undistribute
+from lm_eval.models.utils import (
+    Collator,
+    configure_pad_token,
+    handle_stop_sequences,
+    undistribute,
+)
 from lm_eval.utils import (
     eval_logger,
     get_rolling_token_windows,

@@ -346,6 +351,7 @@ class VLLM(TemplateLM):
             desc="Running generate_until requests",
         )
         # for each different set of kwargs, we execute all requests, by batch.
+        eos = self.tokenizer.decode(self.eot_token_id)
         for chunk in chunks:
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)

@@ -353,27 +359,14 @@ class VLLM(TemplateLM):
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
-                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
+                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tokenizer.decode(self.eot_token_id)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
lm_eval/models/vllm_vlms.py

@@ -7,7 +7,12 @@ from tqdm import tqdm
 from lm_eval.api.instance import Instance
 from lm_eval.api.registry import register_model
-from lm_eval.models.utils import Collator, replace_placeholders, undistribute
+from lm_eval.models.utils import (
+    Collator,
+    handle_stop_sequences,
+    replace_placeholders,
+    undistribute,
+)
 from lm_eval.models.vllm_causallms import VLLM
 from lm_eval.utils import eval_logger

@@ -225,7 +230,7 @@ class VLLM_VLM(VLLM):
             group_fn=lambda x: x[1],
         )
         chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
-
+        eos = self.tokenizer.decode(self.eot_token_id)
         for chunk in chunks:
             contexts, all_gen_kwargs, aux_arguments = zip(*chunk)

@@ -241,27 +246,14 @@ class VLLM_VLM(VLLM):
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
-            until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                if "until" in kwargs.keys():
-                    until = kwargs.pop("until")
-                    if isinstance(until, str):
-                        until = [until]
-                    elif not isinstance(until, list):
-                        raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
-                        )
+                # add EOS token to stop sequences
+                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
             else:
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                 )
-            # add EOS token to stop sequences
-            eos = self.tokenizer.decode(self.eot_token_id)
-            if not until:
-                until = [eos]
-            else:
-                until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
tests/models/test_api.py

@@ -63,13 +63,13 @@ def test_create_payload_loglikelihood(api):
         (
             ["Hello, how are"],
             True,
-            {"max_gen_toks": 100, "temperature": 0.7},
+            {"max_gen_toks": 100, "temperature": 0.7, "until": ["hi"]},
             {
                 "prompt": "Hello, how are",
                 "model": "gpt-3.5-turbo",
                 "max_tokens": 100,
                 "temperature": 0.7,
-                "stop": ["<|endoftext|>"],
+                "stop": ["hi"],
                 "seed": 1234,
             },
         ),

@@ -82,7 +82,7 @@ def test_create_payload_loglikelihood(api):
                 "model": "gpt-3.5-turbo",
                 "max_tokens": 256,
                 "temperature": 0,
-                "stop": ["<|endoftext|>"],
+                "stop": [],
                 "seed": 1234,
             },
         ),