gaoqiong / lm-evaluation-harness

Commit 0d1ef037, authored Jan 17, 2024 by lintangsutawika
"solved merge conflict"
Parents: aa44be3f, ada4a31d

Changes: 424 files in the merge; this page lists 20 changed files with 660 additions and 522 deletions (+660 -522).
Files on this page:

lm_eval/evaluator.py                                        +73   -95
lm_eval/filters/__init__.py                                  +1    -1
lm_eval/models/__init__.py                                   +1    -0
lm_eval/models/anthropic_llms.py                            +39   -24
lm_eval/models/dummy.py                                      +1    -0
lm_eval/models/gguf.py                                       +5    -2
lm_eval/models/huggingface.py                              +150  -142
lm_eval/models/mamba_lm.py                                 +125    -0
lm_eval/models/openai_completions.py                        +77  -114
lm_eval/models/textsynth.py                                 +19   -12
lm_eval/models/vllm_causallms.py                           +133  -117
lm_eval/prompts/__init__.py                                  +0    -2
lm_eval/tasks/__init__.py                                   +30    -5
lm_eval/tasks/anli/anli_r1.yaml                              +1    -1
lm_eval/tasks/arc/arc_easy.yaml                              +1    -1
lm_eval/tasks/arithmetic/arithmetic_1dc.yaml                 +1    -1
lm_eval/tasks/asdiv/default.yaml                             +1    -1
lm_eval/tasks/babi/babi.yaml                                 +1    -1
lm_eval/tasks/bbh/_generate_configs.py                       +0    -2
lm_eval/tasks/bbh/cot_fewshot/_cot_fewshot_template_yaml     +1    -1
lm_eval/evaluator.py (+73 -95)

Module header: the unused `import json` and `import sys` are removed, and the unused `make_table` and `create_iterator` names are dropped from the `from lm_eval.utils import (...)` block; `positional_deprecated`, `run_task_tests`, `get_git_commit_hash`, `simple_parse_args_string` and `eval_logger` are kept.

@@ -91,7 +87,7 @@ def simple_evaluate(...): the CLI-override warning loses a stray f-string prefix:

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            "generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks."
        )
        if gen_kwargs == "":
            gen_kwargs = None

@@ -118,7 +114,9 @@ def simple_evaluate(...): the per-rank cache-db suffix is reflowed onto separate lines; behaviour is unchanged:

            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )
    task_dict = lm_eval.tasks.get_task_dict(tasks)

@@ -234,9 +232,6 @@ def evaluate(...): the `task_order` and `task_group_alias` bookkeeping dicts (and the comment "store the ordering of tasks and groups") are deleted; only `padding_requests`, `task_hierarchy` and `num_fewshot` remain.

@@ -264,14 +259,14 @@ def evaluate(...): aliases are written directly into `results` instead of into `task_group_alias`:

-        task_group_alias[task_name] = configs[task_name]["task_alias"]
+        results[task_name]["alias"] = configs[task_name]["task_alias"]
...
-        and (group_name not in task_group_alias)
+        and (group_name not in results)
...
-        task_group_alias[group_name] = configs[task_name]["group_alias"]
+        results[group_name]["alias"] = configs[task_name]["group_alias"]

@@ -440,32 +435,6 @@ def evaluate(...): the rank-0 "Get task ordering for correct sample-wide aggregation" block is removed entirely (the `group_to_task` / `task_to_group` construction), so the "Aggregate results over all datapoints" section now follows directly.

@@ -505,7 +474,10 @@ def evaluate(...): per-task metrics are copied and stripped of their alias before sample counting:

-                metrics = results[task]
+                metrics = results[task].copy()
+
+                if "alias" in metrics:
+                    metrics.pop("alias")

                 current_size = metrics.pop("samples")

@@ -564,71 +536,77 @@ def evaluate(...): `print_tasks()` is rewritten. The old version took `(task_hierarchy, task_order, task_version, task_group_alias)`, stored a numeric "tab" per row, and relied on two follow-up loops that popped "tab"/"samples" and applied `task_group_alias` to build the indented aliases. The new version takes `(task_hierarchy, results, tab=0)` and pops "samples" and prefixes aliases as it recurses:

        def print_tasks(task_hierarchy, results, tab=0):
            results_agg = collections.defaultdict(dict)
            groups_agg = collections.defaultdict(dict)

            (group_name, task_list), *_ = task_hierarchy.items()
            task_list = sorted(task_list)

            results_agg[group_name] = results[group_name].copy()
            # results_agg[group_name]["tab"] = tab
            if "samples" in results_agg[group_name]:
                results_agg[group_name].pop("samples")

            tab_string = " " * tab + "- " if tab > 0 else ""

            if "alias" in results_agg[group_name]:
                results_agg[group_name]["alias"] = (
                    tab_string + results_agg[group_name]["alias"]
                )
            else:
                results_agg[group_name]["alias"] = tab_string + group_name

            if len(task_list) > 0:
                groups_agg[group_name] = results[group_name].copy()
                # groups_agg[group_name]["tab"] = tab
                if "samples" in groups_agg[group_name]:
                    groups_agg[group_name].pop("samples")

                if "alias" in groups_agg[group_name]:
                    groups_agg[group_name]["alias"] = (
                        tab_string + groups_agg[group_name]["alias"]
                    )
                else:
                    groups_agg[group_name]["alias"] = tab_string + group_name

                for task_name in task_list:
                    if task_name in task_hierarchy:
                        _task_hierarchy = {
                            **{task_name: task_hierarchy[task_name]},
                            **task_hierarchy,
                        }
                    else:
                        _task_hierarchy = {
                            **{task_name: []},
                            **task_hierarchy,
                        }

                    _results_agg, _groups_agg = print_tasks(
                        _task_hierarchy, results, tab + 1
                    )
                    results_agg = {**results_agg, **_results_agg}
                    groups_agg = {**groups_agg, **_groups_agg}

            return results_agg, groups_agg

The single call `results_agg, groups_agg, versions = print_tasks(task_hierarchy, task_order, versions, task_group_alias)` and the two alias post-processing loops are replaced by a driver loop that keeps invoking the new function on the not-yet-aggregated part of the hierarchy:

        results_agg = collections.defaultdict(dict)
        groups_agg = collections.defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        left_tasks_list = []
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list != []:
                ...
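The recursion above replaces the old `task_order` bookkeeping: each level passes `tab + 1` down and prefixes the row's alias with indentation. A minimal standalone sketch of just that prefixing rule (not the harness's own code; the toy hierarchy and the numbers are made up for illustration):

    # Illustrative sketch of the alias-indentation rule used by the new print_tasks():
    # each nesting level indents by one space per level and adds "- " for non-root rows.
    results = {
        "mmlu": {"acc": 0.41},           # toy numbers, not real results
        "mmlu_stem": {"acc": 0.38},
        "mmlu_physics": {"acc": 0.33},
    }
    hierarchy = {"mmlu": ["mmlu_stem"], "mmlu_stem": ["mmlu_physics"], "mmlu_physics": []}

    def alias_rows(hierarchy, results, root, tab=0):
        tab_string = " " * tab + "- " if tab > 0 else ""
        rows = [(tab_string + root, results[root])]
        for child in sorted(hierarchy.get(root, [])):
            rows += alias_rows(hierarchy, results, child, tab + 1)
        return rows

    for alias, metrics in alias_rows(hierarchy, results, "mmlu"):
        print(f"{alias:<20} {metrics['acc']:.2f}")

Running this prints the same kind of indented table rows ("mmlu", " - mmlu_stem", "  - mmlu_physics") that the rewritten evaluator now produces without the separate "tab" field.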
lm_eval/filters/__init__.py (+1 -1)

@@ -32,7 +32,7 @@ def build_filter_ensemble(filter_name, components): the redundant parentheses in the loop target are dropped:

    filters = []
-    for (function, kwargs) in components:
+    for function, kwargs in components:
        if kwargs is None:
            f = get_filter(function)()
        else:
            ...
lm_eval/models/__init__.py (+1 -0)

@@ -5,5 +5,6 @@ from . import dummy: the new Mamba backend is registered by importing its module:

 from . import anthropic_llms
 from . import gguf
 from . import vllm_causallms
+from . import mamba_lm

 # TODO: implement __all__
lm_eval/models/anthropic_llms.py (+39 -24)

Imports are regrouped (typing, tqdm, then the lm_eval modules), `import time` is dropped, and `retry_on_specific_exceptions` is imported from `lm_eval.utils`; `eval_logger = utils.eval_logger` is unchanged.

@@ -45,26 +48,30 @@ def anthropic_completion(...): the hand-rolled `while True` loop with `backoff_time = 3`, `time.sleep(backoff_time)` and `backoff_time *= 1.5` around `client.completions.create(...)` is replaced by the shared retry helper:

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        eval_logger.warning(
            f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
        )

    @retry_on_specific_exceptions(
        on_exceptions=[anthropic.RateLimitError],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def completion():
        response = client.completions.create(
            prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
            model=model,
            # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
            # (e.g. gsm8k's ":") may truncate a lot of the input.
            stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
            max_tokens_to_sample=max_tokens_to_sample,
            temperature=temperature,
            **kwargs,
        )
        return response.completion

    return completion()

@@ -141,6 +148,14 @@: `generate_until()` now verifies the `anthropic` package itself before doing any work:

     def generate_until(self, requests) -> List[str]:
+        try:
+            import anthropic
+        except ModuleNotFoundError:
+            raise Exception(
+                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
+please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
+            )
+
         if not requests:
             return []
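All three API backends touched in this commit (Anthropic here, OpenAI and TextSynth below) swap their hand-rolled back-off loops for the shared `retry_on_specific_exceptions` helper from `lm_eval.utils`. Roughly, such a decorator looks like the sketch below; this is a simplified re-implementation for illustration, not the harness's actual helper, and the initial sleep and growth factor are assumptions:

    import time
    from typing import Callable, List, Optional, Type

    def retry_on_specific_exceptions(
        on_exceptions: List[Type[Exception]],
        max_retries: Optional[int] = None,   # None = retry forever
        backoff_time: float = 3.0,           # assumed initial sleep
        backoff_multiplier: float = 1.5,     # assumed growth factor
        on_exception_callback: Optional[Callable[[Exception, float], None]] = None,
    ):
        """Decorator: retry the wrapped call with exponential back-off
        whenever one of `on_exceptions` is raised."""
        def decorator(func):
            def wrapper(*args, **kwargs):
                sleep_time = backoff_time
                attempt = 0
                while True:
                    try:
                        return func(*args, **kwargs)
                    except tuple(on_exceptions) as e:
                        attempt += 1
                        if max_retries is not None and attempt > max_retries:
                            raise
                        if on_exception_callback is not None:
                            on_exception_callback(e, sleep_time)
                        time.sleep(sleep_time)
                        sleep_time *= backoff_multiplier
            return wrapper
        return decorator

The design pay-off is that each backend only declares which exceptions are retryable and what to log, instead of repeating the sleep/multiply loop.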
lm_eval/models/dummy.py (+1 -0)

Whitespace-only change: apparently a blank line is added in the import block (`import random`, then the `lm_eval.api.model` and `lm_eval.api.registry` imports); no code changes.
lm_eval/models/gguf.py (+5 -2)

Import regrouping only: `import requests` moves below the standard-library imports (`logging`, `time`), `from requests.exceptions import RequestException` and `from tqdm import tqdm` are reordered, and blank lines separate the groups; `logger = logging.getLogger(__name__)` is unchanged.
lm_eval/models/huggingface.py (+150 -142)

This diff is collapsed on the page; its contents are not shown here.
lm_eval/models/mamba_lm.py (new file, +125 -0)

from typing import Optional, Union

import torch

from lm_eval import utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM


@register_model("mamba_ssm")
class MambaLMWrapper(HFLM):
    def __init__(
        self,
        pretrained="state-spaces/mamba-130m",
        **kwargs,
    ) -> None:
        """
        Mamba (via the `mamba_ssm` package) supports the following args:

            d_model: int,
            n_layer: int,
            vocab_size: int,
            initializer_cfg=None,
            pad_vocab_size_multiple: int = 1,
            ssm_cfg=None,
            norm_epsilon: float = 1e-5,
            rms_norm: bool = False,
            initializer_cfg=None,
            fused_add_norm=False,
            residual_in_fp32=False,

        See https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L175 for more info.
        The above can all be passed via `--model_args` or to this __init__() directly
        but we recommend placing many of these within the config.json file uploaded alongside your
        Mamba model to the HF Hub instead.
        All other HuggingFace from_pretrained() kwargs, such as those related to
        `parallelize=True`, PEFT, autoGPTQ, or any sub-configurations of these advanced args,
        are unsupported by the `mamba_ssm` package.

        The HFLM arguments
        `backend`, `revision`, `subfolder`, `tokenizer`, `truncation`, `max_length`,
        `device`, `dtype`, `batch_size`, `max_batch_size`, `trust_remote_code`, `use_fast_tokenizer`
        are all supported by Mamba where they do not conflict
        with Mamba-specific restrictions such as causal LMs only.
        """
        if "backend" in kwargs:
            # mamba currently only supports causal models
            assert kwargs["backend"] == "causal"

        super().__init__(
            pretrained=pretrained,
            # set appropriate defaults for tokenizer, max length, etc
            backend=kwargs.get("backend", "causal"),
            tokenizer=kwargs.get("tokenizer", "EleutherAI/gpt-neox-20b"),
            max_length=kwargs.get("max_length", 2048),
            **kwargs,
        )

    def _get_config(
        self,
        pretrained: str,
        **kwargs,
    ) -> None:
        try:
            from mamba_ssm.utils.hf import load_config_hf  # noqa: F811
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
            )

        self._config = load_config_hf(pretrained)

    def _create_model(
        self,
        pretrained: str,
        dtype: Optional[Union[str, torch.dtype]] = "float16",
        # no `parallelize=True` options
        # no PEFT and quantization options
        # Mamba does not support arbitrary HF from_pretrained() args
        **kwargs,
    ) -> None:
        try:
            from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel  # noqa: F811
        except ModuleNotFoundError:
            raise Exception(
                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
            )

        self._model = MambaLMHeadModel.from_pretrained(
            pretrained,
            device=self._device,
            dtype=torch.float16 if dtype == "auto" else utils.get_dtype(dtype),
            **kwargs,
        )

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        for key in ("do_sample", "attention_mask"):
            if key in generation_kwargs:
                generation_kwargs.pop(key)

        # mamba's custom GenerationMixin currently does not support
        # passing stopping criteria.
        # for the time being, we simply generate to max length,
        # then truncate (equivalent result)
        # -- this should be revisited to speed up generation
        # stopping_criteria = stop_sequences_criteria(
        #     self.tokenizer, stop, 1, context.shape[0]
        # )

        return self.model.generate(
            input_ids=context,
            max_length=max_length,
            # stopping_criteria=stopping_criteria,
            # pad_token_id=self.tokenizer.pad_token_id,
            # use_cache=True,
            **generation_kwargs,
        )
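Once `mamba_ssm` is installed (`pip install -e .[mamba]`), the new backend is selected by its registered name. A hedged sketch of a Python-API invocation; the task name, batch size and device are placeholders for illustration, not values taken from this commit:

    # Illustrative only: evaluate the new mamba_ssm backend through the Python API.
    # Requires a CUDA GPU and the `mamba_ssm` extra.
    from lm_eval import evaluator, tasks

    tasks.initialize_tasks()  # register the built-in task configs first

    results = evaluator.simple_evaluate(
        model="mamba_ssm",
        model_args="pretrained=state-spaces/mamba-130m",
        tasks=["lambada_openai"],   # placeholder task
        batch_size=8,               # placeholder batch size
        device="cuda:0",
    )
    print(results["results"])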
lm_eval/models/openai_completions.py (+77 -114)

Imports are regrouped; `import time` is removed, and `from importlib.util import find_spec` plus `from lm_eval.utils import retry_on_specific_exceptions` are added ahead of `def get_result(response, ctxlen: int) -> Tuple[float, bool]:`.

@@ -44,24 +45,28 @@ def oa_completion(**kwargs): the try/except import of `openai, tiktoken` and the manual back-off loop are replaced by a `find_spec` check and the shared retry helper:

    if not find_spec("openai") or not find_spec("tiktoken"):
        raise Exception(
            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
            "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
        )
    else:
        import openai

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        import traceback

        traceback.print_exc()

    @retry_on_specific_exceptions(
        on_exceptions=[openai.OpenAIError],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def completion():
        return openai.completions.create(**kwargs)

    return completion()

@@ -71,7 +76,7 @@ class OpenaiCompletionsLM(LM): the `model` argument loses its `"text-davinci-003"` default and becomes required.

@@ -81,14 +86,15 @@: the docstring's example engine changes from `davinci` to `gpt-3.5-turbo-instruct`, and `import openai, tiktoken  # noqa: E401` is split into two separate imports.

@@ -102,7 +108,7 @@: the comment above `openai.api_key = os.environ["OPENAI_API_KEY"]` is corrected from `OPENAI_API_SECRET_KEY` to `OPENAI_API_KEY`.

@@ -154,8 +160,9 @@: the empty-context branch of loglikelihood encoding is reflowed:

    if context == "":
        # end of text as context
        context_enc, continuation_enc = (
            [self.eot_token_id],
            self.tok_encode(continuation),
        )
    else:
        context_enc, continuation_enc = self._encode_pair(context, continuation)

@@ -247,6 +254,7 @@: inside the request loop, `self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks)` is now set per chunk before contexts are encoded.

@@ -326,68 +334,68 @@ def oa_chat_completion(client, **kwargs): the same refactor as `oa_completion`: `find_spec` replaces the try/except import, the unused `async def _get_completions(**kwargs)` helper and the back-off loop are removed, and the call becomes a `retry_on_specific_exceptions`-decorated `completion()` that returns `client.chat.completions.create(**kwargs)`.

`OpenaiChatCompletionsLM` is re-registered and reworked for locally hosted endpoints:

-@register_model("openai-chat-completions")
+@register_model("openai-chat-completions", "local-chat-completions")
 class OpenaiChatCompletionsLM(LM):
     def __init__(
-        self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
+        self,
+        model: str = "gpt-3.5-turbo",  # GPT model or Local model using HuggingFace model paths
+        base_url: str = None,
+        truncate: bool = False,
+        **kwargs,
     ) -> None:

The docstring now states that the class "Implements an OpenAI-style chat completion API for accessing both OpenAI OR locally-hosted models using HuggingFace Tokenizer ... using the **gen_kwargs passed on init". Only `openai` is imported (tiktoken is no longer required here); the hard-coded sampling attributes (`frequency_penalty`, `logit_bias`, `n`, `presence_penalty`, `temperature`, `top_p`), the tiktoken tokenizer, `vocab_size`, `end_of_text_token_id` and the `eot_token_id` property are all removed. The client is built from `base_url` when one is given:

        # Read from environment variable OPENAI_API_KEY
        # Set to EMPTY for local
        if self.base_url:
            self.client = openai.OpenAI(base_url=self.base_url)
        else:
            self.client = openai.OpenAI()  # openai.AsyncOpenAI()

@@ -408,53 +416,19 @@: the now-unneeded `tok_encode`, `tok_decode`, `_encode_pair` and `sameuntil_chunks` helpers and the `_collate` closure are deleted. Requests are still grouped by their generation kwargs with `utils.Grouper(requests, lambda x: str(x.args[1]))`, and each group is reordered with `utils.Reorderer([req.args for req in reqs], lambda x: (-len(x[0]), x[0]))` before the progress bar loop.

@@ -468,37 +442,26 @@: generation kwargs are passed through instead of being rebuilt from instance attributes. The new branch starts with `if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict):`, drops the old `do_sample` pop, maps `until` to `kwargs["stop"]` and `max_gen_toks` to `kwargs["max_tokens"]`, updates the two ValueError messages to use repr() wording, and then calls:

            response = oa_chat_completion(
                client=self.client, messages=inps, model=self.model, **kwargs
            )

replacing the old call that spelled out `frequency_penalty`, `max_tokens`, `n`, `presence_penalty`, `temperature` and `top_p` explicitly.
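The second registry name added here, `local-chat-completions`, together with the new `base_url` argument, lets the same class talk to any locally hosted OpenAI-compatible server. A hedged sketch of how that might be invoked from Python; the URL, model name and task are placeholders, and a dummy API key is set because local servers typically ignore it:

    # Illustrative only: point the chat-completions backend at a local
    # OpenAI-compatible endpoint (e.g. a vLLM or FastChat server).
    import os
    from lm_eval import evaluator, tasks

    os.environ.setdefault("OPENAI_API_KEY", "EMPTY")  # "Set to EMPTY for local"
    tasks.initialize_tasks()

    results = evaluator.simple_evaluate(
        model="local-chat-completions",
        model_args="model=meta-llama/Llama-2-7b-chat-hf,base_url=http://localhost:8000/v1",
        tasks=["gsm8k"],   # placeholder generative task
    )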
lm_eval/models/textsynth.py (+19 -12)

@@ -13,11 +13,13 @@: `import time` is dropped from the header and `from lm_eval.utils import retry_on_specific_exceptions` is added.

@@ -27,21 +29,26 @@ def textsynth_completion(**kwargs): the manual back-off loop around `_requests.post(**kwargs)` is replaced by the shared retry helper:

    def _exception_callback(e: Exception, sleep_time: float) -> None:
        import traceback

        traceback.print_exc()

    @retry_on_specific_exceptions(
        on_exceptions=[_requests.exceptions.RequestException],
        max_retries=None,  # retry forever, consider changing
        on_exception_callback=_exception_callback,
    )
    def completion():
        return _requests.post(**kwargs)

    return completion()

`TextSynthLM.__init__` gains a trailing catch-all parameter: `def __init__(self, engine, truncate: bool = False, **kwargs) -> None:`.

@@ -149,7 +156,7 @@: the error log for a response without generated `text` loses a stray f-string prefix on its first fragment (the `{resp}` placeholder sits in the second, non-f fragment and was never interpolated either way).
lm_eval/models/vllm_causallms.py (+133 -117)

Module header: the unused `from collections import defaultdict`, `from transformers import AutoTokenizer` and `Any` imports are removed; `from importlib.util import find_spec` is added; and `Collator`, `divide`, `eval_logger`, `get_rolling_token_windows` and `make_disjoint_window` are imported directly from `lm_eval.utils` (so the separate `eval_logger = utils.eval_logger` assignment goes away). The optional block now imports `ray`, `ray.util.multiprocessing.Pool`, `vllm.LLM`, `vllm.SamplingParams` and `get_tokenizer` in one try/except. `run_inference_one_model()` keeps its body (`llm = LLM(**model_args); return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)`) but its `requests` annotation becomes `List[List[int]]` and the commented-out `gpu_id` / `CUDA_VISIBLE_DEVICES` lines are deleted.

@@ -40,7 +49,7 @@ class VLLM(LM): the `quantization` argument widens from `Optional[Literal["awq"]]` to `Optional[str]`.

@@ -54,12 +63,10 @@: the try/except `import vllm` check becomes `if not find_spec("vllm"): raise Exception(...)` with the install hint capitalized; the CUDA-only assertion is unchanged.

@@ -85,17 +92,30 @@: `__init__` now resolves batch size up front and loads a HF config for data-parallel runs:

        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
            else batch_size
        )
        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
        else:
            self.model_args["worker_use_ray"] = True
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")

        from transformers import AutoConfig

        self._config = AutoConfig.from_pretrained(
            pretrained, trust_remote_code=trust_remote_code, revision=revision
        )

The tokenizer setup via `get_tokenizer(...)` is unchanged, and the old trailing `self.batch_size = batch_size` assignment is removed.

@@ -107,9 +127,18 @@: `max_length` no longer trusts `tokenizer.model_max_length` blindly. With a single engine it returns `self.model.llm_engine.model_config.max_model_len`; otherwise it walks `("n_positions", "max_position_embeddings", "n_ctx")` on `self._config`, falls back to `tokenizer.model_max_length` (treating the sentinel 1000000000000000019884624838656 as unset), and finally to `self._DEFAULT_MAX_LENGTH`.

@@ -155,13 +184,13 @@: in the data-parallel branch of `_model_generate`, `utils.divide` becomes the directly imported `divide`, and `ray.shutdown()` is now invoked after the pool finishes ("to prevent hang-ups if subsequent calls required") before the per-worker results are flattened.

@@ -170,7 +199,6 @@ and @@ -193,8 +221,9 @@: a blank line is removed near `_encode_pair`, and the empty-context branch of loglikelihood encoding is reflowed the same way as in openai_completions.py.

@@ -209,8 +238,8 @@: `loglikelihood_rolling` calls the directly imported `make_disjoint_window` and `get_rolling_token_windows` instead of the `utils.`-prefixed names.

@@ -233,8 +262,7 @@ and @@ -250,84 +278,73 @@ def generate_until(...): results are collected in a flat list (`res = []`) instead of a per-key defaultdict, and the `utils.Grouper` + per-key `utils.Reorderer` + `utils.chunks` pipeline is replaced by the new `Collator`. `_collate_gen` now returns `-len(_requests[0][1]), _requests[0][0]` rather than a tuple of the token ids, and the loop becomes:

        re_ords = Collator(requests, _collate_gen, grouping=True)
        chunks = re_ords.get_batched(
            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
        )

        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
        # for each different set of kwargs, we execute all requests, by batch.
        for chunk in chunks:
            context_and_encoding, all_gen_kwargs = zip(*chunk)
            context, context_encoding = zip(*context_and_encoding)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            ...

The per-chunk handling of `until`, `max_gen_toks`, context truncation to `max_ctx_len = self.max_length - max_gen_toks` and the batched `self._model_generate(...)` call are carried over; generated texts are appended to the flat `res`, cached via `self.cache_hook.add_partial("generate_until", (context, gen_kwargs), generated_text)`, and the method ends with `return re_ords.get_original(res)`.

@@ -340,16 +357,15 @@ def _loglikelihood_tokens(...): request batching likewise moves to `re_ord = Collator(requests, sort_fn=_collate)` with `chunks = re_ord.get_batched(n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None)`, and the local accumulator `inps` is renamed `inputs`.

@@ -357,18 +373,18 @@: the result loop now also iterates the (possibly truncated) inputs, `for output, ctxlen, (cache_key, _, _), inp in zip(outputs, ctxlens, chunk, inputs):`, and `_parse_logprobs` is called with keyword arguments `tokens=inp, outputs=output, ctxlen=ctxlen` instead of rebuilding `context_enc + continuation_enc`.

@@ -385,9 +401,9 @@ and @@ -397,11 +413,11 @@ def _parse_logprobs(...): the docstring now describes `tokens` as "Input tokens (potentially left-truncated)" and `outputs` as containing `prompt_logprobs`; the comments are clarified to "The first entry of prompt_logprobs is None because the model has no previous tokens to condition on." and "assume ctxlen always >= 1".
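For `data_parallel_size > 1`, each worker in a `ray.util.multiprocessing.Pool` builds its own `LLM` from `model_args` (the pattern is adapted from the vLLM issue comment cited in the file), and `ray.shutdown()` is now called afterwards to avoid hang-ups on later calls. A stripped-down sketch of that pattern, assuming `vllm` and `ray` are installed and that `model_args`, `sampling_params` and the tokenized `requests` are already prepared; this is illustrative, not the harness's exact code:

    # Minimal sketch of the data-parallel generation path.
    import ray
    from ray.util.multiprocessing import Pool
    from vllm import LLM

    def run_inference_one_model(model_args: dict, sampling_params, requests):
        # each pool worker constructs its own vLLM engine
        llm = LLM(**model_args)
        return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)

    def generate_data_parallel(model_args, sampling_params, requests, dp_size):
        # contiguous, order-preserving shards of the token-id lists
        shard_size = (len(requests) + dp_size - 1) // dp_size
        shards = [requests[i : i + shard_size] for i in range(0, len(requests), shard_size)]
        inputs = [(model_args, sampling_params, shard) for shard in shards]
        with Pool(dp_size) as pool:
            results = pool.starmap(run_inference_one_model, inputs)
        ray.shutdown()  # prevent hang-ups if generation is invoked again later
        # flatten per-worker outputs back into one list
        return [item for sublist in results for item in sublist]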
lm_eval/prompts/__init__.py (+0 -2)

@@ -69,7 +69,6 @@ def get_prompt(...) and @@ -113,7 +112,6 @@ class PromptString: two lines, apparently blank, are removed (one after the `load_prompt_list(use_prompt, dataset_name=None, subset_name=None, yaml_path=None, **kwargs)` signature, one inside `PromptString.apply()`); no functional change.
lm_eval/tasks/__init__.py (+30 -5)

@@ -61,11 +61,27 @@ def register_configurable_group(config, yaml_path=None): when a group entry names a task that is already registered, its existing config now serves as the base and the task is re-registered under a group-prefixed name:

    for task_config in config_list:
        base_config = {}
        task_name_config = {}
        if "task" in task_config:
            task_name = task_config["task"]
            if task_name in ALL_TASKS:
                task_obj = get_task_dict(task_name)[task_name]
                if type(task_obj) == tuple:
                    _, task_obj = task_obj
                if task_obj is not None:
                    base_config = task_obj._config.to_dict()
                    task_name_config["task"] = f"{group}_{task_name}"

        task_config = utils.load_yaml_config(yaml_path, task_config)
        var_configs = check_prompt_config(
            {
                **base_config,
                **task_config,
                **{"group": group},
                **task_name_config,
            },
            yaml_path=os.path.dirname(yaml_path),
        )

@@ -131,7 +147,10 @@ def include_task_folder(task_dir, register_task=True): the directory walk no longer reverses its order, and a flag tracks load failures:

-    for root, subdirs, file_list in reversed(list(os.walk(task_dir))):
+    # Track whether any tasks failed during loading
+    import_fail = False
+    for root, subdirs, file_list in os.walk(task_dir):

@@ -155,20 +174,27 @@: the missing-dependency handler also catches ImportError and sets the flag, and unexpected errors are promoted from debug to warning:

-            except ModuleNotFoundError as e:
+            except (ImportError, ModuleNotFoundError) as e:
+                import_fail = True
                 eval_logger.debug(
                     f"{yaml_path}: {e}. Config will not be added to registry."
                 )
             except Exception as error:
                 import traceback

-                eval_logger.debug(
-                    "Failed to load config in\n"
+                eval_logger.warning(
+                    "Unexpected error loading config in\n"
                     f" {yaml_path}\n"
                     " Config will not be added to registry\n"
                     f" Error: {error}\n"
                     f" Traceback: {traceback.format_exc()}"
                 )

After the walk, a single summary warning is emitted if anything failed to load:

    if import_fail:
        eval_logger.warning(
            "Some tasks could not be loaded due to missing dependencies."
            " Run with `--verbosity DEBUG` for full details."
        )
    return 0

@@ -180,7 +206,6 @@ def include_path(task_dir) / def initialize_tasks(verbosity="INFO"): one line, apparently blank, is removed; no functional change.
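Because individual task configs can now fail to import without aborting registration, the single summary warning points users at `--verbosity DEBUG`. From the Python API the equivalent is to initialize with a higher verbosity before adding external task folders; a hedged sketch, with the extra task directory being a hypothetical placeholder:

    # Illustrative: surface per-file load errors that are otherwise logged at DEBUG level.
    from lm_eval import tasks

    tasks.initialize_tasks(verbosity="DEBUG")       # registers the built-in task YAMLs
    tasks.include_path("/path/to/my/custom/tasks")  # hypothetical extra task directory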
lm_eval/tasks/anli/anli_r1.yaml, lm_eval/tasks/arc/arc_easy.yaml, lm_eval/tasks/arithmetic/arithmetic_1dc.yaml, lm_eval/tasks/asdiv/default.yaml (+1 -1 each)

In each of these task configs the `metadata` block changes from a single-item list to a mapping:

 metadata:
-  - version: 1.0
+  version: 1.0

lm_eval/tasks/babi/babi.yaml (+1 -1)

The same `metadata` restructuring, and the task version is bumped:

 metadata:
-  - version: 0.0
+  version: 1.0
lm_eval/tasks/bbh/_generate_configs.py (+0 -2)

@@ -24,7 +24,6 @@ and @@ -37,7 +36,6 @@: two lines, apparently blank, are removed (after `args = parse_args()` and before the per-task loop that fetches the BIG-Bench-Hard CoT prompts from raw.githubusercontent.com); no functional change.
lm_eval/tasks/bbh/cot_fewshot/_cot_fewshot_template_yaml (+1 -1)

The shared template's `metadata` block likewise becomes a mapping, and its version is bumped:

 num_fewshot: 0
 metadata:
-  - version: 1.0
+  version: 2.0