Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
4f29f7cc
Commit
4f29f7cc
authored
Jun 23, 2023
by
haileyschoelkopf
Browse files
merge big-refactor into fix branch
parents
f832c776
9dea125b
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
136 additions
and
29 deletions
+136
-29
lm_eval/tasks/super_glue/boolq/seq2seq.yaml
lm_eval/tasks/super_glue/boolq/seq2seq.yaml
+18
-0
lm_eval/tasks/super_glue/cb/default.yaml
lm_eval/tasks/super_glue/cb/default.yaml
+2
-1
lm_eval/utils.py
lm_eval/utils.py
+112
-0
main.py
main.py
+1
-10
scripts/write_out.py
scripts/write_out.py
+0
-14
setup.py
setup.py
+2
-0
tests/test_description.py
tests/test_description.py
+1
-2
tests/test_evaluator.py
tests/test_evaluator.py
+0
-2
No files found.
lm_eval/tasks/super_glue/boolq/seq2seq.yaml
0 → 100644
View file @
4f29f7cc
group
:
-
super-glue-lm-eval-v1
task
:
"
boolq-seq2seq"
dataset_path
:
super_glue
dataset_name
:
boolq
output_type
:
greedy_until
training_split
:
train
validation_split
:
validation
doc_to_text
:
"
{{passage}}
\n
Question:
{{question}}
\n
Answer:"
doc_to_target
:
"
{{answer_choices[label]}}"
gold_alias
:
"
{{label}}"
# this will be cast to an int.
template_aliases
:
"
{%
set
answer_choices
=
['no',
'yes']
%}"
metric_list
:
-
metric
:
exact_match
aggregation
:
mean
higher_is_better
:
true
ignore_case
:
true
ignore_punctuation
:
true
lm_eval/tasks/super_glue/cb/default.yaml
View file @
4f29f7cc
...
...
@@ -7,7 +7,8 @@ output_type: multiple_choice
training_split
:
train
validation_split
:
validation
doc_to_text
:
"
{{premise}}
\n
Question:
{{hypothesis}}.
True,
False,
or
Neither?
\n
Answer:"
doc_to_target
:
"
{{label}}"
# this will be cast to an int.
doc_to_target
:
"
{{answer_choices[label]}}"
gold_alias
:
"
{{label}}"
# this will be cast to an int.
template_aliases
:
"
{%
set
answer_choices
=
['True',
'False',
'Neither']
%}"
metric_list
:
-
metric
:
acc
...
...
lm_eval/utils.py
View file @
4f29f7cc
...
...
@@ -14,6 +14,7 @@ from typing import List, Union
import
gc
import
torch
import
transformers
from
omegaconf
import
OmegaConf
from
jinja2
import
BaseLoader
,
Environment
,
StrictUndefined
...
...
@@ -391,7 +392,13 @@ def load_yaml_config(yaml_path):
return
yaml_config
def regex_replace(string, pattern, repl, count=0):
    """Custom Jinja filter exposing `re.sub`.

    Replaces matches of `pattern` in `string` with `repl`. `count` caps the
    number of replacements; the default of 0 replaces every match.
    """
    compiled = re.compile(pattern)
    return compiled.sub(repl, string, count=count)
# Module-level Jinja environment, configured with BaseLoader as the loader and
# StrictUndefined as the `undefined` handler.
env = Environment(loader=BaseLoader, undefined=StrictUndefined)
# Expose regex_replace to templates under the filter name "regex_replace".
env.filters["regex_replace"] = regex_replace
def
apply_template
(
template
,
doc
):
...
...
@@ -408,6 +415,111 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
return
islice
(
raw_iterator
,
rank
,
limit
,
world_size
)
def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
    """
    Pad each tensor with zeros up to `max_length` along dim 0 and stack the
    results into a single batch tensor. Used for batching inputs and
    continuations in seq2seq models.

    :param max_length: target length of dim 0 for every tensor in the batch.
    :param tensors: tensors to batch; each is expected to be a 1-D `[seq]`
        tensor. The list and its elements are left untouched (the result is
        built from a fresh list rather than by overwriting entries in place).
    :param padding_side: "right" appends the zero padding after the sequence,
        "left" prepends it.
    :return: the concatenation along dim 0 of every padded tensor after a
        leading batch dimension has been added to it.
    """
    assert (
        padding_side == "left" or padding_side == "right"
    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

    rows = []
    for tensor in tensors:
        pad_len = max_length - tensor.shape[0]
        if pad_len > 0:
            # Padding is always zeros of dtype long, on the tensor's own device.
            padding = torch.zeros(
                pad_len,
                dtype=torch.long,
                device=tensor.device,
            )
            if padding_side == "right":
                pieces = [tensor, padding]  # [seq] + [padding_length - seq]
            else:
                pieces = [padding, tensor]  # [padding_length - seq] + [seq]
            tensor = torch.cat(pieces, dim=0)
        # Add the batch dimension so the rows can be concatenated along dim 0.
        rows.append(tensor.unsqueeze(0))

    return torch.cat(rows, dim=0)
def clear_torch_cache():
    """Run a garbage-collection pass, then call `torch.cuda.empty_cache()`."""
    gc.collect()
    torch.cuda.empty_cache()
def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Resolve `dtype` to a torch dtype when it is given by name.

    A string other than "auto" is looked up as an attribute of the `torch`
    module (`"float16"` -> `torch.float16`). The string "auto" and any
    non-string value are returned unchanged. Does not use an instantiated
    HF AutoConfig.
    """
    # Nothing to convert: already a non-string value, or the "auto" sentinel.
    if not isinstance(dtype, str) or dtype == "auto":
        return dtype
    return getattr(torch, dtype)
# Multi-token stopping criteria
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Stopping criterion that fires once every batch row has produced `sequence`.

    Each row is tracked separately in `done_tracker`; once a row has been
    marked done it is not re-checked on later calls.
    """

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ):
        self.tokenizer = tokenizer
        self.sequence = sequence
        # Token ids of the stop sequence, encoded without special tokens;
        # their count sets how many trailing tokens are inspected per call.
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_id_len = len(self.sequence_ids)
        # Number of leading positions of `input_ids` skipped before matching.
        self.initial_decoder_input_length = initial_decoder_input_length
        # One flag per batch row; True once the stop sequence was seen there.
        self.done_tracker = [False] * batch_size

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Return True when every row in the batch has hit the stop sequence."""
        # For efficiency, only the last n tokens are decoded, where n is the
        # number of tokens in the stop sequence.
        generated_ids = input_ids[:, self.initial_decoder_input_length :]
        tail_ids = generated_ids[:, -self.sequence_id_len :]
        decoded_tails = self.tokenizer.batch_decode(tail_ids)

        for row, finished in enumerate(self.done_tracker):
            if finished:
                continue
            self.done_tracker[row] = self.sequence in decoded_tails[row]

        return all(self.done_tracker)
def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    """Build a `StoppingCriteriaList` holding one `MultiTokenEOSCriteria` per
    entry of `stop_sequences`, in the same order."""
    criteria = []
    for stop_sequence in stop_sequences:
        criteria.append(
            MultiTokenEOSCriteria(
                stop_sequence, tokenizer, initial_decoder_input_length, batch_size
            )
        )
    return transformers.StoppingCriteriaList(criteria)
main.py
View file @
4f29f7cc
...
...
@@ -41,7 +41,6 @@ def parse_args():
parser
.
add_argument
(
"--data_sampling"
,
type
=
float
,
default
=
None
)
parser
.
add_argument
(
"--no_cache"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--decontamination_ngrams_path"
,
default
=
None
)
parser
.
add_argument
(
"--description_dict_path"
,
default
=
None
)
parser
.
add_argument
(
"--check_integrity"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--write_out"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--output_base_path"
,
type
=
str
,
default
=
None
)
...
...
@@ -78,12 +77,6 @@ def main():
eval_logger
.
info
(
f
"Selected Tasks:
{
task_names
}
"
)
# TODO: description_dict?
# description_dict = {}
# if args.description_dict_path:
# with open(args.description_dict_path, "r") as f:
# description_dict = json.load(f)
results
=
evaluator
.
simple_evaluate
(
model
=
args
.
model
,
model_args
=
args
.
model_args
,
...
...
@@ -94,7 +87,6 @@ def main():
device
=
args
.
device
,
no_cache
=
args
.
no_cache
,
limit
=
args
.
limit
,
# description_dict=description_dict,
decontamination_ngrams_path
=
args
.
decontamination_ngrams_path
,
check_integrity
=
args
.
check_integrity
,
write_out
=
args
.
write_out
,
...
...
@@ -103,8 +95,7 @@ def main():
if
results
is
not
None
:
samples
=
results
.
pop
(
"samples"
)
dumped
=
json
.
dumps
(
results
,
indent
=
2
)
dumped
=
json
.
dumps
(
results
,
indent
=
2
,
default
=
lambda
o
:
str
(
o
))
print
(
dumped
)
batch_sizes
=
","
.
join
(
map
(
str
,
results
[
"config"
][
"batch_sizes"
]))
...
...
scripts/write_out.py
View file @
4f29f7cc
...
...
@@ -13,12 +13,10 @@ def parse_args():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--output_base_path"
,
required
=
True
)
parser
.
add_argument
(
"--tasks"
,
default
=
"all_tasks"
)
parser
.
add_argument
(
"--provide_description"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--sets"
,
type
=
str
,
default
=
"val"
)
# example: val,test
parser
.
add_argument
(
"--num_fewshot"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
42
)
parser
.
add_argument
(
"--num_examples"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--description_dict_path"
,
default
=
None
)
return
parser
.
parse_args
()
...
...
@@ -32,11 +30,6 @@ def main():
task_names
=
args
.
tasks
.
split
(
","
)
task_dict
=
tasks
.
get_task_dict
(
task_names
)
# description_dict = {}
# if args.description_dict_path:
# with open(args.description_dict_path, "r") as f:
# description_dict = json.load(f)
os
.
makedirs
(
args
.
output_base_path
,
exist_ok
=
True
)
for
task_name
,
task
in
task_dict
.
items
():
rnd
=
random
.
Random
()
...
...
@@ -55,12 +48,6 @@ def main():
docs
=
join_iters
(
iters
)
# description = (
# description_dict[task_name]
# if description_dict and task_name in description_dict
# else ""
# )
with
open
(
os
.
path
.
join
(
args
.
output_base_path
,
task_name
),
"w"
)
as
f
:
for
i
,
doc
in
(
zip
(
range
(
args
.
num_examples
),
docs
)
...
...
@@ -72,7 +59,6 @@ def main():
doc
=
doc
,
num_fewshot
=
args
.
num_fewshot
,
rnd
=
rnd
,
# description=description,
)
f
.
write
(
ctx
+
"
\n
"
)
...
...
setup.py
View file @
4f29f7cc
...
...
@@ -28,7 +28,9 @@ setuptools.setup(
python_requires
=
">=3.9"
,
install_requires
=
[
"accelerate>=0.18.0"
,
"evaluate"
,
"datasets>=2.0.0"
,
"evaluate>=0.4.0"
,
"jsonlines"
,
"numexpr"
,
"openai>=0.6.4"
,
...
...
tests/test_description
_dict
.py
→
tests/test_description.py
View file @
4f29f7cc
...
...
@@ -3,7 +3,7 @@ import lm_eval.tasks
import
lm_eval.models
def
test_description
_dict
():
def
test_description
():
seed
=
42
num_examples
=
1
task_names
=
[
"arc_challenge"
,
"lambada"
]
...
...
@@ -41,6 +41,5 @@ def test_description_dict():
doc
=
doc
,
num_fewshot
=
1
,
rnd
=
rnd
,
description
=
description
,
)
assert
description
in
ctx
tests/test_evaluator.py
View file @
4f29f7cc
...
...
@@ -61,7 +61,6 @@ def test_evaluator(taskname, task_class):
num_fewshot
=
0
,
limit
=
limit
,
bootstrap_iters
=
10
,
description_dict
=
None
,
)
e2
=
evaluator
.
evaluate
(
lm
=
lm
,
...
...
@@ -69,7 +68,6 @@ def test_evaluator(taskname, task_class):
num_fewshot
=
0
,
limit
=
limit
,
bootstrap_iters
=
10
,
description_dict
=
None
,
)
# check that caching is working
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment