gaoqiong / lm-evaluation-harness

Commit 1576e99e
Authored Jan 13, 2023 by Albert Jiang
Parent: 99fdec0c

allow task-specific descriptions to be passed to request construction
Showing 5 changed files with 113 additions and 107 deletions.
lm_eval/base.py (+23, -44)
lm_eval/evaluator.py (+12, -2)
lm_eval/models/gpt2.py (+17, -10)
lm_eval/tasks/__init__.py (+0, -1)
lm_eval/tasks/hendrycks_math.py (+61, -50)
lm_eval/base.py

@@ -7,6 +7,7 @@ import os
 import json
 import hashlib
 import datasets
+import inspect
 from sqlitedict import SqliteDict
 from tqdm import tqdm
 import torch
@@ -329,7 +330,7 @@ class BaseLM(LM):
         return re_ord.get_original(res)

-    def multiple_temperature_sample_until(self, requests, k=32, temperature=0.3):
+    def generate(self, requests):
         res = []

         def _collate(x):
@@ -338,21 +339,33 @@ class BaseLM(LM):
         re_ord = utils.Reorderer(requests, _collate)

-        for context, until in tqdm(re_ord.get_reordered()):
+        for request in tqdm(re_ord.get_reordered()):
+            if len(request) == 2:
+                # Unpack greedy sample request
+                context, until = request
+                k, temperature = 1, 0.
+                _model_generate_kwargs = {}
+            elif len(request) == 4:
+                # Unpack temperature sample request
+                context, until, k, temperature = request
+                for key in ["k", "temperature"]:
+                    assert key in inspect.getfullargspec(self._model_generate).args, \
+                        f"Model generation parameter '{key}' not accepted as an argument for _model_generate"
+                _model_generate_kwargs = {"k": k, "temperature": temperature}
+            else:
+                raise AssertionError
+
             if isinstance(until, str):
                 until = [until]

             (primary_until,) = self.tok_encode(until[0])

             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)

-            assert context_enc.shape[0] == 1
-            context_enc = context_enc.expand(k, context_enc.shape[1])
             cont = self._model_generate(
-                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until, temperature=temperature
+                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until, **_model_generate_kwargs
             )

             generated_tokens = cont[:, context_enc.shape[1]:]
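Requests reach `generate` as plain tuples, so the arity check above is the whole dispatch. A small illustration (hypothetical prompt text; the tuples mirror what `rf.generate` produces with and without the sampling arguments):

    # 2-tuple: greedy request, decoded once at temperature 0.
    request = ("Problem: 1 + 1 = ?\nAnswer:", ["\n"])

    # 4-tuple: temperature-sample request, k candidates at the given temperature.
    request = ("Problem: 1 + 1 = ?\nAnswer:", ["\n"], 16, 1.0)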
@@ -361,50 +374,16 @@ class BaseLM(LM):
             s = [candidate.split(term)[0] for candidate in s]

             # partial caching
-            self.cache_hook.add_partial("multiple_temperature_sample_until", (context, until, k, temperature), s)
+            self.cache_hook.add_partial("generate", (context, until, k, temperature), s)

             res.append(s)

         return re_ord.get_original(res)

     def greedy_until(self, requests):
-        # TODO: implement fully general `until` that handles until that are
-        #       multiple tokens or that span multiple tokens correctly
-        # TODO: extract to TokenizedLM?
-        res = []
-
-        def _collate(x):
-            toks = self.tok_encode(x[0])
-            return len(toks), x[0]
-
-        re_ord = utils.Reorderer(requests, _collate)
-
-        for context, until in tqdm(re_ord.get_reordered()):
-            if isinstance(until, str):
-                until = [until]
-
-            (primary_until,) = self.tok_encode(until[0])
-
-            context_enc = torch.tensor(
-                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
-            ).to(self.device)
-
-            cont = self._model_generate(
-                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
-            )
-
-            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
-
-            for term in until:
-                s = s.split(term)[0]
-
-            # partial caching
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
-
-            res.append(s)
-
-        return re_ord.get_original(res)
+        return self.generate(requests)


 class Task(abc.ABC):
@@ -881,7 +860,7 @@ class CachingLM:
REQUEST_RETURN_LENGTHS
=
{
"loglikelihood"
:
2
,
"greedy_until"
:
None
,
"
multiple_temperature_sample_until
"
:
None
,
"
generate
"
:
None
,
"loglikelihood_rolling"
:
None
,
}
...
...
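(Registering "generate" here matters because the task code below issues requests via rf.generate(...); this table presumably tells the request machinery how many values each request type returns, with None marking a free-form generation rather than a fixed-size tuple.)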
lm_eval/evaluator.py

@@ -2,6 +2,7 @@ import collections
 import itertools
 import numpy as np
 import random
+import inspect
 import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
@@ -177,6 +178,7 @@ def evaluate(
     docs = {}
     docs_for_decontamination = collections.defaultdict(list)
+    task_to_description = {}

     # get lists of each type of request
     for task_name, task in task_dict_items:
@@ -203,6 +205,7 @@ def evaluate(
             if description_dict and task_name in description_dict
             else ""
         )
+        task_to_description[task_name] = description

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
@@ -215,7 +218,10 @@ def evaluate(
             ctx = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
-            reqs = task.construct_requests(doc, ctx)
+            if "description" in inspect.getfullargspec(task.construct_requests).args:
+                reqs = task.construct_requests(doc, ctx, description=description)
+            else:
+                reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):
                 reqs = [reqs]
             for i, req in enumerate(reqs):
@@ -262,7 +268,11 @@ def evaluate(
         task = task_dict[task_name]
         doc = docs[(task_name, doc_id)]

-        metrics = task.process_results(doc, requests)
+        # be backward compatible with tasks that do not allow description_dict in process_results
+        if "description" in inspect.getfullargspec(task.process_results).args:
+            metrics = task.process_results(doc, requests, task_to_description[task_name])
+        else:
+            metrics = task.process_results(doc, requests)
         for metric, value in metrics.items():
             vals[(task_name, metric)].append(value)
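The same signature-sniffing idiom appears at both call sites. A minimal, self-contained sketch of how it keeps old task classes working (the two task classes here are hypothetical, for illustration only):

    import inspect

    class OldTask:
        # pre-change signature: no description parameter
        def process_results(self, doc, results):
            return {"acc": 0}

    class NewTask:
        # post-change signature: accepts an optional description
        def process_results(self, doc, results, description=""):
            return {"acc": 0}

    for task in (OldTask(), NewTask()):
        # only pass the description if the task's method can accept it
        if "description" in inspect.getfullargspec(task.process_results).args:
            metrics = task.process_results({}, [], "majority_voting=16")
        else:
            metrics = task.process_results({}, [])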
lm_eval/models/gpt2.py
@@ -121,16 +121,23 @@ class HFLM(BaseLM):
         with torch.no_grad():
             return self.gpt2(inps)[0][:, :, :50257]

-    def _model_generate(self, context, max_length, eos_token_id, temperature=0.):
-        assert temperature >= 0.
-        if temperature == 0.:
-            return self.gpt2.generate(
-                context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
-            )
-        else:
-            return self.gpt2.generate(
-                context, max_length=max_length, eos_token_id=eos_token_id, do_sample=True, temperature=temperature
-            )
+    def _model_generate(self, context, max_length, eos_token_id, k=1, temperature=0.):
+        assert (
+            isinstance(k, int) and k >= 1
+        ), f"Incorrect number of candidates to generate: {k}"
+        assert temperature >= 0., f"Negative sampling temperature: {temperature}"
+        # Whether to sample or to decode greedily
+        do_sample = (temperature != 0.)
+        if not do_sample:
+            # If decoding greedily, only sample once
+            assert k == 1, f"Decoding greedily but {k} generations"
+        if k > 1:
+            context = context.expand(k, context.shape[1])
+        return self.gpt2.generate(
+            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=do_sample, temperature=temperature
+        )

     # for backwards compatibility
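Both decoding modes now share a single exit path. A rough sketch of the two call shapes, reusing the variable names from the `generate` loop in base.py above (illustrative values, not taken from the commit):

    # Greedy: k stays 1 and temperature 0., so do_sample is False.
    cont = lm._model_generate(
        context_enc, context_enc.shape[1] + lm.max_gen_toks, primary_until
    )

    # Sampled: the single context row is expanded to k copies, so
    # gpt2.generate returns k candidate continuations in one batch.
    cont = lm._model_generate(
        context_enc, context_enc.shape[1] + lm.max_gen_toks, primary_until,
        k=16, temperature=0.3,
    )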
lm_eval/tasks/__init__.py
@@ -157,7 +157,6 @@ TASK_REGISTRY = {
     "mutual_plus": mutual.MuTualPlus,
     # math
     "math_algebra": hendrycks_math.MathAlgebra,
-    "math_algebra_maj@k": hendrycks_math.MathAlgebraMaj,
     "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
     "math_geometry": hendrycks_math.MathGeometry,
     "math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
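The dedicated maj@k registry entry is dropped because majority voting is now requested per task through the description mechanism instead of a separate task class. Assuming the harness's usual description_dict plumbing, the equivalent configuration would look something like:

    # Hypothetical description_dict entry replacing the removed
    # "math_algebra_maj@k" task: 16 samples at temperature 1.0.
    description_dict = {"math_algebra": "majority_voting=16,sampling_temperature=1.0"}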
lm_eval/tasks/hendrycks_math.py
@@ -27,6 +27,8 @@ _CITATION = """
 class Math(Task):
     DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math)
     DATASET_NAME = None
+    MAJORITY_VOTING = "majority_voting"
+    SAMPLING_TEMPERATURE = "sampling_temperature"

     def has_training_docs(self):
         return True
@@ -62,21 +64,72 @@ class Math(Task):
     def doc_to_target(self, doc):
         return " " + doc["solution"]

-    def construct_requests(self, doc, ctx):
-        return rf.greedy_until(ctx, ["\n"])
+    def parse_description(self, description):
+        """description is a string with comma-separated key=value tuples,
+        e.g. "majority_voting=32,sampling_temperature=1.0"
+        """
+        parsed_dict = {}
+        for term in description.split(","):
+            if not term.strip():
+                continue
+            key, value = term.split("=")
+            parsed_dict[key] = value
+        return parsed_dict
+
+    def construct_requests(self, doc, ctx, description=""):
+        if not description.strip():
+            return rf.generate(ctx, ["\n"])
+        parsed_description = self.parse_description(description=description)
+        majority_voting_value = int(parsed_description.get(self.MAJORITY_VOTING, 1))
+        sampling_temperature_value = float(parsed_description.get(self.SAMPLING_TEMPERATURE, 1.0))
+        return rf.generate(ctx, ["\n"], majority_voting_value, sampling_temperature_value)
+
+    def get_pure_answer(self, candidate):
+        indices = [pos for pos, char in enumerate(candidate) if char == "$"]
+        if len(indices) <= 1:
+            return candidate
+        return candidate[indices[0] + 1 : indices[-1]]

-    def process_results(self, doc, results):
+    def majority_vote(self, candidates):
+        answers = []
+        for candidate in candidates:
+            answer = self.get_pure_answer(candidate)
+            try:
+                answer = self.remove_boxed(self.last_boxed_only_string(answer))
+            except:
+                answer = None
+            answers.append(answer)
+        answer_votes = {}
+        for answer in answers:
+            answer_votes[answer] = answer_votes.get(answer, 0) + 1
+        max_vote = 0
+        elected = None
+        for answer, vote in answer_votes.items():
+            if vote > max_vote and answer is not None:
+                elected = answer
+                max_vote = vote
+        return elected
+
+    def process_results(self, doc, results, description=""):
         retval = 0
-        indices = [pos for pos, char in enumerate(results[0]) if char == "$"]
-        if len(indices) <= 1:
-            answer = results[0]
-        else:
-            answer = results[0][indices[0] + 1 : indices[-1]]
+        if description == "":
+            answer = self.get_pure_answer(results[0])
+        elif self.MAJORITY_VOTING in self.parse_description(description):
+            answer = self.majority_vote(results[0])
+        else:
+            raise AssertionError

         if self.is_equiv(
             answer, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
         ):
             retval = 1
         return {"acc": retval}

     def aggregation(self):
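Note that parse_description leaves values as strings; construct_requests casts them afterwards (int for the vote count, float for the temperature). Using the docstring's own example:

    parse_description("majority_voting=32,sampling_temperature=1.0")
    # -> {"majority_voting": "32", "sampling_temperature": "1.0"}  (values remain strings)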
@@ -286,48 +339,6 @@ class MathAlgebra(Math):
     DATASET_NAME = "algebra"


-class MathAlgebraMaj(Math):
-    VERSION = 1
-    DATASET_NAME = "algebra"
-
-    def construct_requests(self, doc, ctx):
-        return rf.multiple_temperature_sample_until(ctx, ["\n"])
-
-    def process_results(self, doc, results):
-        retval = 0
-        candidates = results[0]
-        answers = []
-        for candidate in candidates:
-            indices = [pos for pos, char in enumerate(candidate) if char == "$"]
-            if len(indices) <= 1:
-                answer = candidate
-            else:
-                answer = candidate[indices[0] + 1 : indices[-1]]
-            try:
-                answer = self.remove_boxed(self.last_boxed_only_string(answer))
-            except:
-                answer = None
-            answers.append(answer)
-        answer_votes = {}
-        for answer in answers:
-            answer_votes[answer] = answer_votes.get(answer, 0) + 1
-        max_vote = 0
-        elected = None
-        for answer, vote in answer_votes.items():
-            if vote > max_vote and answer is not None:
-                elected = answer
-                max_vote = vote
-        if self.is_equiv(
-            elected, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
-        ):
-            retval = 1
-        return {"acc": retval}


 class MathCountingAndProbability(Math):
     VERSION = 1
     DATASET_NAME = "counting_and_probability"