gaoqiong / lm-evaluation-harness

Commit 1576e99e, authored Jan 13, 2023 by Albert Jiang

allow task-specific descriptions to be passed to request construction

parent 99fdec0c
Showing 5 changed files with 113 additions and 107 deletions (+113, -107).
lm_eval/base.py                  +23  -44
lm_eval/evaluator.py             +12  -2
lm_eval/models/gpt2.py           +17  -10
lm_eval/tasks/__init__.py        +0   -1
lm_eval/tasks/hendrycks_math.py  +61  -50
lm_eval/base.py
@@ -7,6 +7,7 @@ import os
 import json
 import hashlib
 import datasets
+import inspect
 from sqlitedict import SqliteDict
 from tqdm import tqdm
 import torch
@@ -329,7 +330,7 @@ class BaseLM(LM):
         return re_ord.get_original(res)

-    def multiple_temperature_sample_until(self, requests, k=32, temperature=0.3):
+    def generate(self, requests):
         res = []

         def _collate(x):
@@ -338,21 +339,33 @@ class BaseLM(LM):
         re_ord = utils.Reorderer(requests, _collate)

-        for context, until in tqdm(re_ord.get_reordered()):
+        for request in tqdm(re_ord.get_reordered()):
+            if len(request) == 2:
+                # Unpack greedy sample request
+                context, until, = request
+                k, temperature = 1, 0.
+                _model_generate_kwargs = {}
+            elif len(request) == 4:
+                # Unpack temperature sample request
+                context, until, k, temperature = request
+                for key in ["k", "temperature"]:
+                    assert key in inspect.getfullargspec(self._model_generate).args, \
+                        f"Model generation parameter '{key}' not accepted as an argument for _model_generate"
+                _model_generate_kwargs = {"k": k, "temperature": temperature}
+            else:
+                raise AssertionError
             if isinstance(until, str):
                 until = [until]

             (primary_until,) = self.tok_encode(until[0])

             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)

+            assert context_enc.shape[0] == 1
+            context_enc = context_enc.expand(k, context_enc.shape[1])
+
             cont = self._model_generate(
-                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until,
-                temperature=temperature
+                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until,
+                **_model_generate_kwargs
             )

             generated_tokens = cont[:, context_enc.shape[1]:]
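Note on the added expand: `context_enc.expand(k, context_enc.shape[1])` turns the single encoded prompt into k identical rows so that one `_model_generate` call can draw all k samples in a batch. A minimal standalone sketch of that tensor step (illustrative values, not harness code):

import torch

context_enc = torch.tensor([[464, 3280, 318]])         # shape (1, 3): one encoded prompt
assert context_enc.shape[0] == 1

k = 4
batched = context_enc.expand(k, context_enc.shape[1])  # shape (4, 3): k rows of the same prompt
assert batched.shape == (4, 3)
# expand() returns views that share storage, so no prompt data is copied;
# that is safe here because generation only reads the input ids.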
@@ -361,50 +374,16 @@ class BaseLM(LM):
                 s = [candidate.split(term)[0] for candidate in s]

             # partial caching
-            self.cache_hook.add_partial("multiple_temperature_sample_until", (context, until, k, temperature), s)
+            self.cache_hook.add_partial("generate", (context, until, k, temperature), s)

             res.append(s)

         return re_ord.get_original(res)

     def greedy_until(self, requests):
         # TODO: implement fully general `until` that handles until that are
         #       multiple tokens or that span multiple tokens correctly
         # TODO: extract to TokenizedLM?
-        res = []
+        return self.generate(requests)
-
-        def _collate(x):
-            toks = self.tok_encode(x[0])
-            return len(toks), x[0]
-
-        re_ord = utils.Reorderer(requests, _collate)
-
-        for context, until in tqdm(re_ord.get_reordered()):
-            if isinstance(until, str):
-                until = [until]
-
-            (primary_until,) = self.tok_encode(until[0])
-
-            context_enc = torch.tensor(
-                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
-            ).to(self.device)
-
-            cont = self._model_generate(
-                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
-            )
-
-            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
-
-            for term in until:
-                s = s.split(term)[0]
-
-            # partial caching
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
-
-            res.append(s)
-
-        return re_ord.get_original(res)

 class Task(abc.ABC):
@@ -881,7 +860,7 @@ class CachingLM:
 REQUEST_RETURN_LENGTHS = {
     "loglikelihood": 2,
     "greedy_until": None,
-    "multiple_temperature_sample_until": None,
+    "generate": None,
     "loglikelihood_rolling": None,
 }
lm_eval/evaluator.py
@@ -2,6 +2,7 @@ import collections
 import itertools
 import numpy as np
 import random
+import inspect
 import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
@@ -177,6 +178,7 @@ def evaluate(
     docs = {}
     docs_for_decontamination = collections.defaultdict(list)
+    task_to_description = {}

     # get lists of each type of request
     for task_name, task in task_dict_items:
@@ -203,6 +205,7 @@ def evaluate(
             if description_dict and task_name in description_dict
             else ""
         )
+        task_to_description[task_name] = description

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
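For reference, `description_dict` maps task names to free-form description strings; this commit additionally caches each one in `task_to_description` so it can later reach `construct_requests` and `process_results`. A hypothetical dict using the key=value grammar that lm_eval/tasks/hendrycks_math.py parses below:

# Hypothetical example; the task names and key=value grammar appear elsewhere in this diff.
description_dict = {
    "math_algebra": "majority_voting=16,sampling_temperature=0.7",  # 16 samples at T=0.7
    "math_geometry": "",                                            # empty -> plain greedy decoding
}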
@@ -215,7 +218,10 @@ def evaluate(
             ctx = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
-            reqs = task.construct_requests(doc, ctx)
+            if "description" in inspect.getfullargspec(task.construct_requests).args:
+                reqs = task.construct_requests(doc, ctx, description=description)
+            else:
+                reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):
                 reqs = [reqs]
             for i, req in enumerate(reqs):
@@ -262,7 +268,11 @@ def evaluate(
         task = task_dict[task_name]
         doc = docs[(task_name, doc_id)]

-        metrics = task.process_results(doc, requests)
+        # be backward compatible with tasks that do not allow description_dict in process_results
+        if "description" in inspect.getfullargspec(task.process_results).args:
+            metrics = task.process_results(doc, requests, task_to_description[task_name])
+        else:
+            metrics = task.process_results(doc, requests)
         for metric, value in metrics.items():
             vals[(task_name, metric)].append(value)
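The backward-compatibility guard works because `inspect.getfullargspec` lists a function's named parameters, so only tasks whose `process_results` declares `description` receive it. A small standalone illustration:

import inspect

def old_style(doc, results):                   # legacy task signature
    return {"acc": 0}

def new_style(doc, results, description=""):   # description-aware signature
    return {"acc": 0}

print("description" in inspect.getfullargspec(old_style).args)  # False -> called as before
print("description" in inspect.getfullargspec(new_style).args)  # True  -> gets the description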
lm_eval/models/gpt2.py
@@ -121,16 +121,23 @@ class HFLM(BaseLM):
         with torch.no_grad():
             return self.gpt2(inps)[0][:, :, :50257]

-    def _model_generate(self, context, max_length, eos_token_id, temperature=0.):
-        assert temperature >= 0.
-        if temperature == 0.:
-            return self.gpt2.generate(
-                context,
-                max_length=max_length,
-                eos_token_id=eos_token_id,
-                do_sample=False  # Whether to sample or to decode greedily
-            )
-        else:
-            return self.gpt2.generate(
-                context,
-                max_length=max_length,
-                eos_token_id=eos_token_id,
-                do_sample=True,
-                temperature=temperature
-            )
+    def _model_generate(self, context, max_length, eos_token_id, k=1, temperature=0.):
+        assert (isinstance(k, int) and k >= 1), f"Incorrect number of candidates to generate: {k}"
+        assert temperature >= 0., f"Negative sampling temperature: {temperature}"
+        do_sample = (temperature != 0.)
+        if not do_sample:
+            # If decoding greedily, only sample once
+            assert k == 1, f"Decoding greedily but {k} generations"
+        if k > 1:
+            context = context.expand(k, context.shape[1])
+        return self.gpt2.generate(
+            context,
+            max_length=max_length,
+            eos_token_id=eos_token_id,
+            do_sample=do_sample,
+            temperature=temperature
+        )

     # for backwards compatibility
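A rough usage sketch of what the new `_model_generate` amounts to for HF models, assuming a stock `transformers` GPT-2 checkpoint (illustrative only, not harness code):

import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tok = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

context = tok("The answer is", return_tensors="pt").input_ids  # shape (1, seq_len)
k, temperature = 4, 0.3
batched = context.expand(k, context.shape[1])                  # k rows of the same prompt

with torch.no_grad():
    out = model.generate(
        batched,
        max_length=context.shape[1] + 16,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.eos_token_id,   # GPT-2 has no pad token; reuse EOS
        do_sample=(temperature != 0.),   # greedy iff temperature == 0
        temperature=temperature,
    )
print(tok.batch_decode(out[:, context.shape[1]:]))             # k sampled continuations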
lm_eval/tasks/__init__.py
@@ -157,7 +157,6 @@ TASK_REGISTRY = {
     "mutual_plus": mutual.MuTualPlus,
     # math
     "math_algebra": hendrycks_math.MathAlgebra,
-    "math_algebra_maj@k": hendrycks_math.MathAlgebraMaj,
     "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
     "math_geometry": hendrycks_math.MathGeometry,
     "math_intermediate_algebra": hendrycks_math.MathIntermediateAlgebra,
lm_eval/tasks/hendrycks_math.py
@@ -27,6 +27,8 @@ _CITATION = """
 class Math(Task):
     DATASET_PATH = inspect.getfile(lm_eval.datasets.hendrycks_math.hendrycks_math)
     DATASET_NAME = None
+    MAJORITY_VOTING = "majority_voting"
+    SAMPLING_TEMPERATURE = "sampling_temperature"

     def has_training_docs(self):
         return True
@@ -62,21 +64,72 @@ class Math(Task):
     def doc_to_target(self, doc):
         return " " + doc["solution"]

-    def construct_requests(self, doc, ctx):
-        return rf.greedy_until(ctx, ["\n"])
+    def parse_description(self, description):
+        """description is a string with comma-separated key=value tuples
+        e.g.: "majority_voting=32,sampling_temperature=1.0"
+        """
+        parsed_dict = {}
+        for term in description.split(","):
+            if not term.strip():
+                continue
+            key, value = term.split("=")
+            parsed_dict[key] = value
+        return parsed_dict
+
+    def construct_requests(self, doc, ctx, description=""):
+        if not description.strip():
+            return rf.generate(ctx, ["\n"])
+
+        parsed_description = self.parse_description(description=description)
+        majority_voting_value = int(parsed_description.get(self.MAJORITY_VOTING, 1))
+        sampling_temperature_value = float(parsed_description.get(self.SAMPLING_TEMPERATURE, 1.0))
+        return rf.generate(ctx, ["\n"], majority_voting_value, sampling_temperature_value)
+
+    def get_pure_answer(self, candidate):
+        indices = [pos for pos, char in enumerate(candidate) if char == "$"]
+        if len(indices) <= 1:
+            return candidate
+        return candidate[indices[0] + 1 : indices[-1]]
+
+    def majority_vote(self, candidates):
+        answers = []
+        for candidate in candidates:
+            answer = self.get_pure_answer(candidate)
+            try:
+                answer = self.remove_boxed(self.last_boxed_only_string(answer))
+            except:
+                answer = None
+            answers.append(answer)
+        answer_votes = {}
+        for answer in answers:
+            answer_votes[answer] = answer_votes.get(answer, 0) + 1
+        max_vote = 0
+        elected = None
+        for answer, vote in answer_votes.items():
+            if vote > max_vote and answer is not None:
+                elected = answer
+                max_vote = vote
+        return elected

-    def process_results(self, doc, results):
+    def process_results(self, doc, results, description=""):
         retval = 0
-        indices = [pos for pos, char in enumerate(results[0]) if char == "$"]
-        if len(indices) <= 1:
-            answer = results[0]
-        else:
-            answer = results[0][indices[0] + 1 : indices[-1]]
+        if description == "":
+            answer = self.get_pure_answer(results[0])
+        elif self.MAJORITY_VOTING in self.parse_description(description):
+            answer = self.majority_vote(results[0])
+        else:
+            raise AssertionError

         if self.is_equiv(
             answer, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
         ):
             retval = 1
         return {"acc": retval}

     def aggregation(self):
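Concretely, `parse_description` only splits the string into raw string values; the callers coerce types. A quick check of the grammar documented in the docstring above:

desc = "majority_voting=32,sampling_temperature=1.0"

parsed = {}
for term in desc.split(","):        # same splitting logic as parse_description
    if not term.strip():
        continue
    key, value = term.split("=")
    parsed[key] = value

assert parsed == {"majority_voting": "32", "sampling_temperature": "1.0"}
assert int(parsed.get("majority_voting", 1)) == 32            # coerced in construct_requests
assert float(parsed.get("sampling_temperature", 1.0)) == 1.0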
@@ -286,48 +339,6 @@ class MathAlgebra(Math):
     DATASET_NAME = "algebra"

-class MathAlgebraMaj(Math):
-    VERSION = 1
-    DATASET_NAME = "algebra"
-
-    def construct_requests(self, doc, ctx):
-        return rf.multiple_temperature_sample_until(ctx, ["\n"])
-
-    def process_results(self, doc, results):
-        retval = 0
-        candidates = results[0]
-        answers = []
-        for candidate in candidates:
-            indices = [pos for pos, char in enumerate(candidate) if char == "$"]
-            if len(indices) <= 1:
-                answer = candidate
-            else:
-                answer = candidate[indices[0] + 1 : indices[-1]]
-            try:
-                answer = self.remove_boxed(self.last_boxed_only_string(answer))
-            except:
-                answer = None
-            answers.append(answer)
-        answer_votes = {}
-        for answer in answers:
-            answer_votes[answer] = answer_votes.get(answer, 0) + 1
-        max_vote = 0
-        elected = None
-        for answer, vote in answer_votes.items():
-            if vote > max_vote and answer is not None:
-                elected = answer
-                max_vote = vote
-        if self.is_equiv(
-            elected, self.remove_boxed(self.last_boxed_only_string(doc["solution"]))
-        ):
-            retval = 1
-        return {"acc": retval}

 class MathCountingAndProbability(Math):
     VERSION = 1
     DATASET_NAME = "counting_and_probability"
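The deleted `MathAlgebraMaj` logic is the same election that now lives in `Math.majority_vote`: count each extracted answer, skip parse failures (`None`), and keep the first answer with the strictly highest vote count. An equivalent sketch using `collections.Counter`:

from collections import Counter

def elect(answers):
    # Most common non-None answer; the first occurrence wins ties,
    # mirroring the strict `vote > max_vote` comparison in majority_vote.
    votes = Counter(a for a in answers if a is not None)
    return votes.most_common(1)[0][0] if votes else None

assert elect(["4", "17", "4", None]) == "4"
assert elect([None, None]) is None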