gaoqiong / lm-evaluation-harness / Commits / 88745155

Commit 88745155, authored Apr 25, 2022 by cjlovering
Initial integration
Parent: 6caa0afd

Showing 6 changed files with 480 additions and 318 deletions (+480 −318)
lm_eval/base.py            +220  −96
lm_eval/evaluator.py        +70  −35
lm_eval/tasks/__init__.py   +42  −35
lm_eval/tasks/coqa.py       +39  −60
lm_eval/tasks/drop.py       +42  −30
lm_eval/tasks/race.py       +67  −62
lm_eval/base.py  (view file @ 88745155)

import abc
from typing import Iterable
import numpy as np
import random
import re
...
@@ -24,17 +25,17 @@ class LM(abc.ABC):
    @abstractmethod
    def loglikelihood(self, requests):
        """Compute log-likelihood of generating a continuation from a context.
        Downstream tasks should attempt to use loglikelihood instead of other
        LM calls whenever possible.

        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
                For example, context="hello" continuation=" world" is correct.

        :return: list
            A list of pairs (logprob, isgreedy)
...
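To make the request format documented above concrete, here is a small illustrative sketch; the context/continuation strings are made up and the returned numbers are placeholders, not real model output.

```python
# Hypothetical loglikelihood requests in the (context, continuation) format
# described above; note that the word-boundary space lives in the continuation.
requests = [
    ("The capital of France is", " Paris"),
    ("", "Hello world"),  # implementations must handle an empty context string
]

# A hypothetical return value: one (logprob, isgreedy) pair per request, where
# isgreedy records whether the continuation is also the greedy decode.
responses = [(-1.7, True), (-42.3, False)]  # placeholder numbers
```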
@@ -97,7 +98,7 @@ class LM(abc.ABC):
        context: str
            Context string
        until: [str]
            The string sequences to generate until. These string sequences
            may each span across multiple tokens, or may be part of one token.

        :return: list
            A list of strings continuation
...
@@ -118,7 +119,6 @@ class LM(abc.ABC):
class BaseLM(LM):
    @property
    @abstractmethod
    def eot_token_id(self):
...
@@ -145,13 +145,16 @@ class BaseLM(LM):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
...
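As a rough illustration of the abstract surface above, the sketch below wires these hooks to a Hugging Face causal LM. It is a minimal sketch under the assumption that these methods, plus the extra properties BaseLM's concrete methods rely on (max_length, max_gen_toks, batch_size, device), are what a subclass has to provide; the real adapters in lm_eval/models are more involved.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval.base import BaseLM


class HFCausalLM(BaseLM):
    """Hypothetical minimal adapter; not the implementation shipped in lm_eval/models."""

    def __init__(self, pretrained="gpt2", device="cpu"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained)
        self.model = AutoModelForCausalLM.from_pretrained(pretrained).to(device)
        self._device = device

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self.model.config.n_positions  # assumes a GPT-2 style config

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        return 1

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        # inps: [batch, seq] token ids -> raw logits [batch, seq, vocab]
        with torch.no_grad():
            return self.model(inps).logits

    def _model_generate(self, context, max_length, eos_token_id):
        return self.model.generate(
            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
        )
```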
@@ -187,23 +190,30 @@ class BaseLM(LM):
        # TODO: automatic batch size detection for vectorization
        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
...
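The windowing call above is easier to picture on a tiny input. The sketch below only re-uses the two utils helpers with the same keyword arguments and prints whatever (context, continuation) token windows they produce; the token list and the prefix token id are made-up values, and no specific output is asserted.

```python
from lm_eval import utils

tokens = list(range(12))  # a made-up "document" of 12 token ids

windows = list(
    map(
        utils.make_disjoint_window,
        utils.get_rolling_token_windows(
            token_list=tokens,
            prefix_token=50256,  # assumption: a GPT-2 style end-of-text id
            max_seq_len=5,
            context_len=1,
        ),
    )
)

# Each window is a (context_tokens, continuation_tokens) pair; together the
# continuations cover the whole document exactly once.
for context_toks, continuation_toks in windows:
    print(context_toks, continuation_toks)
```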
@@ -223,10 +233,12 @@ class BaseLM(LM):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
        ):
            inps = []
            cont_toks_list = []
            inplens = []
...
@@ -252,44 +264,60 @@ class BaseLM(LM):
                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))
...
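The slice-and-gather step above is the heart of per-token scoring, and it is easier to follow on toy tensors. This sketch reproduces the same operations on made-up values so the shapes are visible; none of the numbers are real model output.

```python
import torch
import torch.nn.functional as F

# Hypothetical setup: vocab of 5 tokens, padded input of length 4, and a
# 2-token continuation [3, 1] sitting at the end of the (unpadded) input.
logprobs = F.log_softmax(torch.randn(4, 5), dim=-1)  # [padding_length, vocab]
inplen = 4
cont_toks = torch.tensor([3, 1])
contlen = len(cont_toks)

window = logprobs[inplen - contlen : inplen].unsqueeze(0)        # [1, seq, vocab]
greedy_tokens = window.argmax(dim=-1)                            # [1, seq]
max_equal = bool((greedy_tokens == cont_toks.unsqueeze(0)).all())

# Pick out the log-prob assigned to each continuation token.
token_lls = torch.gather(window, 2, cont_toks.view(1, -1, 1)).squeeze(-1)  # [1, seq]
answer = (float(token_lls.sum()), max_equal)
print(answer)
```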
@@ -301,9 +329,9 @@ class BaseLM(LM):
                res.append(answer)

        return reord.get_original(res)

    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are
        #       multiple tokens or that span multiple tokens correctly
        # TODO: extract to TokenizedLM?
...
@@ -312,29 +340,33 @@ class BaseLM(LM):
        def _collate(x):
            toks = self.tok_encode(x[0])
            return len(toks), x[0]

        reord = utils.Reorderer(requests, _collate)

        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str):
                until = [until]

            (primary_until,) = self.tok_encode(until[0])

            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            cont = self._model_generate(
                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]

            # partial caching
            self.cache_hook.add_partial("greedy_until", (context, until), s)

            res.append(s)

        return reord.get_original(res)
...
@@ -383,7 +415,7 @@ class Task(abc.ABC):
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
...
@@ -412,7 +444,7 @@ class Task(abc.ABC):
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    @abstractmethod
...
@@ -478,22 +510,22 @@ class Task(abc.ABC):
    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass

    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
...
@@ -507,7 +539,7 @@ class Task(abc.ABC):
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores
        """
        pass
...
@@ -516,22 +548,26 @@ class Task(abc.ABC):
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        pass

    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in futures versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
...
@@ -548,7 +584,9 @@ class Task(abc.ABC):
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
...
@@ -556,7 +594,9 @@ class Task(abc.ABC):
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""
...
@@ -569,7 +609,9 @@ class Task(abc.ABC):
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
...
@@ -577,23 +619,90 @@ class Task(abc.ABC):
            # get rid of the doc that's the one we're evaluating, if it's in the fewshot
            fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example


class PromptSourceTask(Task):
    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
        super().__init__(data_dir, cache_dir, download_mode)
        self.prompt = prompt

    def doc_to_target(self, doc):
        _, target = prompt.apply(doc)
        return f" {target}"

    def doc_to_text(self, doc):
        text, _ = prompt.apply(doc)
        return text

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        _requests = []
        if self.prompt.metadata.choices_in_prompt:
            for answer_choice in prompt.get_fixed_answer_choices_list():
                ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                _requests.append(ll_answer_choice)
        else:
            # TODO(Albert): What is the stop symbol? Is it model specific?
            ll_greedy, _ = rf.greedy_until(ctx, ["\nQ:"])
            _requests.append(ll_greedy)
        return _requests

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        raise NotImplementedError(
            "Implement process results using the `prompt.metadata.metrics`. See below."
        )
        if self.prompt.metadata.choices_in_prompt:
            for result, answer_choice in zip(
                prompt.get_fixed_answer_choices_list(), results
            ):
                pass
        else:
            continuation = results

        # Map metric name to HF metric.
        # TODO(Albert): What is Other?
        metric_names = prompt.metadata.metrics
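The new PromptSourceTask delegates input/target rendering to a promptsource template. The sketch below shows, under stated assumptions, what that template object provides: the dataset name "boolq", the choice of the first template, and the doc fields are all made up for illustration, but the apply/metadata calls are the ones the class above relies on.

```python
from promptsource.templates import DatasetTemplates

templates = DatasetTemplates("boolq")                # hypothetical dataset name
prompt = templates[templates.all_template_names[0]]  # pick any available template

# Hypothetical document with BoolQ-style fields.
doc = {"passage": "The sky is blue.", "question": "is the sky blue?", "label": 1}

text, target = prompt.apply(doc)  # rendered input text and gold target string
print(text)
print(target)

# Drives whether construct_requests issues loglikelihood or greedy_until requests.
print(prompt.metadata.choices_in_prompt)
```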

class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]

        return lls
...
@@ -601,21 +710,21 @@ class MultipleChoiceTask(Task):
    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
        }

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
...
@@ -624,7 +733,6 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task, abc.ABC):
    def has_training_docs(self):
        return False
...
@@ -632,9 +740,15 @@ class PerplexityTask(Task, abc.ABC):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
...
@@ -642,7 +756,9 @@ class PerplexityTask(Task, abc.ABC):
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""
...
@@ -665,7 +781,7 @@ class PerplexityTask(Task, abc.ABC):
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
...
@@ -687,23 +803,23 @@ class PerplexityTask(Task, abc.ABC):
    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()
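The cache key construction above is deterministic and easy to check in isolation. A small usage sketch, with placeholder request strings:

```python
from lm_eval.base import hash_args

# Hypothetical loglikelihood request; the strings are placeholders.
key = hash_args("loglikelihood", ("Question: Is the sky blue?\nAnswer:", " yes"))
print(key)  # a stable 64-character sha256 hex digest, used as the cache lookup key
```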

class CacheHook:
    def __init__(self, cachinglm):
        if cachinglm is None:
            self.dbdict = None
            return

        self.dbdict = cachinglm.dbdict

    def add_partial(self, attr, req, res):
        if self.dbdict is None:
            return
...
@@ -733,7 +849,7 @@ class CachingLM:
        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
...
@@ -746,7 +862,7 @@ class CachingLM:
                else:
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)
...
@@ -764,41 +880,48 @@ class CachingLM:
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)


REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"
...
@@ -808,6 +931,7 @@ class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn
...
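Tasks never call the model directly; they build deferred Request objects through RequestFactory and the evaluator resolves them later. A minimal sketch of that round trip, with made-up request strings (the factory is instantiated locally here for illustration; tasks use the shared one exposed by lm_eval.base):

```python
from lm_eval.base import RequestFactory

rf = RequestFactory()

req = rf.loglikelihood("Question: 2 + 2 =", " 4")
print(req)  # Req_loglikelihood('Question: 2 + 2 =', ' 4')[None]

# REQUEST_RETURN_LENGTHS["loglikelihood"] == 2, so the request unpacks into two
# indexed sub-requests; the evaluator routes index 0 (logprob) and index 1
# (is_greedy) to the matching elements of the model's response.
logprob_req, is_greedy_req = req
print(logprob_req.index, is_greedy_req.index)  # 0 1
```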
lm_eval/evaluator.py  (view file @ 88745155)

...
@@ -6,21 +6,33 @@ import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import promptsource
import numpy as np

from promptsource.templates import DatasetTemplates
from lm_eval.utils import positional_deprecated, run_task_tests


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
...
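A usage sketch of the entry point defined above. The "gpt2" model name, the `pretrained=gpt2` arg string, and the single-task list are assumptions for illustration; any registered model and task names would be passed the same way.

```python
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="gpt2",
    model_args="pretrained=gpt2",  # parsed by LM.create_from_arg_string (assumed format)
    tasks=["coqa"],
    num_fewshot=0,
    batch_size=1,
    device="cpu",
    no_cache=True,
)

print(evaluator.make_table(results))
```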
@@ -37,7 +49,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :return
...
@@ -49,20 +61,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + model
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)
...
@@ -72,7 +90,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        description_dict=description_dict,
    )

    # add info about the model and few shot config
...
@@ -85,14 +103,22 @@ def simple_evaluate(model, model_args=None, tasks=[],
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
...
@@ -108,7 +134,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :return
        Dictionary of results
    """
...
@@ -118,12 +144,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
...
@@ -158,15 +186,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
...
@@ -189,11 +218,13 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))

    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
...
@@ -207,25 +238,29 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        metrics = task.process_results(doc, requests)
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

        task_name, prompt_name = task_name.split("+")
        results[task_name]["task_name"] = task_name
        results[task_name]["prompt_name"] = prompt_name

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them less iterations. still looking for a cleaner way to do this
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )

        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    return {"results": dict(results), "versions": dict(versions)}


def make_table(result_dict):
...
@@ -247,9 +282,9 @@ def make_table(result_dict):
            if m + "_stderr" in dic:
                se = dic[m + "_stderr"]
                values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, m, "%.4f" % v, "", ""])
        k = ""
        version = ""
    md_writer.value_matrix = values
...
lm_eval/tasks/__init__.py  (view file @ 88745155)

from promptsource.templates import DatasetTemplates
from pprint import pprint
from typing import List, Union
...
@@ -58,8 +60,8 @@ from . import storycloze

# 6 total
gpt3_translation_benchmarks = {
    "wmt14": ["en-fr", "fr-en"],  # French
    "wmt16": ["en-ro", "ro-en", "de-en", "en-de"],  # German, Romanian
}
...
@@ -67,7 +69,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks = {
    **gpt3_translation_benchmarks,
    "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
    "iwslt17": ["en-ar", "ar-en"],  # Arabic
}

# 319 total
...
@@ -91,7 +93,7 @@ TASK_REGISTRY = {
    "rte": glue.RTE,
    "qnli": glue.QNLI,
    "qqp": glue.QQP,
    # "stsb": glue.STSB,  # not implemented yet
    "sst": glue.SST,
    "wnli": glue.WNLI,
    # SuperGLUE
...
@@ -102,34 +104,26 @@ TASK_REGISTRY = {
    "record": superglue.ReCoRD,
    "wic": superglue.WordsInContext,
    "wsc": superglue.SGWinogradSchemaChallenge,
    # Order by benchmark/genre?
    "coqa": coqa.CoQA,
    "drop": drop.DROP,
    "lambada": lambada.LAMBADA,
    "lambada_cloze": lambada_cloze.LAMBADA_cloze,
    # multilingual lambada
    **lambada_multilingual.construct_tasks(),
    "wikitext": wikitext.WikiText,
    # "cbt-cn": cbt.CBTCN,  # disabled pending context length fix
    # "cbt-ne": cbt.CBTNE,  # disabled pending context length fix
    "piqa": piqa.PiQA,
    "prost": prost.PROST,
    "mc_taco": mc_taco.MCTACO,
    # Science related
    "pubmedqa": pubmedqa.Pubmed_QA,
    "sciq": sciq.SciQ,
    "qasper": qasper.QASPER,
    "qa4mre_2011": qa4mre.QA4MRE_2011,
    "qa4mre_2012": qa4mre.QA4MRE_2012,
    "qa4mre_2013": qa4mre.QA4MRE_2013,
    "triviaqa": triviaqa.TriviaQA,
    "arc_easy": arc.ARCEasy,
    "arc_challenge": arc.ARCChallenge,
...
@@ -140,7 +134,7 @@ TASK_REGISTRY = {
    "squad2": squad.SQuAD2,
    "race": race.RACE,
    # "naturalqs": naturalqs.NaturalQs,  # not implemented yet
    "headqa": headqa.HeadQAEsDeprecated,  # for backwards compat - headqa used to default to es
    "headqa_es": headqa.HeadQAEs,
    "headqa_en": headqa.HeadQAEn,
    "mathqa": mathqa.MathQA,
...
@@ -150,21 +144,17 @@ TASK_REGISTRY = {
    "anli_r1": anli.ANLIRound1,
    "anli_r2": anli.ANLIRound2,
    "anli_r3": anli.ANLIRound3,
    "ethics_cm": hendrycks_ethics.EthicsCM,
    "ethics_deontology": hendrycks_ethics.EthicsDeontology,
    "ethics_justice": hendrycks_ethics.EthicsJustice,
    "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
    "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
    "ethics_virtue": hendrycks_ethics.EthicsVirtue,
    "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
    "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
    # dialogue
    "mutual": mutual.MuTual,
    "mutual_plus": mutual.MuTualPlus,
    # math
    "math_algebra": hendrycks_math.MathAlgebra,
    "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
...
@@ -175,7 +165,6 @@ TASK_REGISTRY = {
    "math_precalc": hendrycks_math.MathPrecalculus,
    "math_asdiv": asdiv.Asdiv,
    "gsm8k": gsm8k.GradeSchoolMath8K,
    # arithmetic
    "arithmetic_2da": arithmetic.Arithmetic2DPlus,
    "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
...
@@ -189,22 +178,18 @@ TASK_REGISTRY = {
    "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
    # TODO Perhaps make these groups of tasks
    #   e.g. anli, arithmetic, openai_translations, harness_translations
    # hendrycksTest (57 tasks)
    **hendrycks_test.create_all_tasks(),
    # e.g. wmt14-fr-en
    **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
    # chef's selection, mostly wmt20
    **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
    # Word Scrambling and Manipulation Tasks
    "anagrams1": unscramble.Anagrams1,
    "anagrams2": unscramble.Anagrams2,
    "cycle_letters": unscramble.CycleLetters,
    "random_insertion": unscramble.RandomInsertion,
    "reversed_words": unscramble.ReversedWords,
    # Pile
    "pile_arxiv": pile.PileArxiv,
    "pile_books3": pile.PileBooks3,
...
@@ -228,7 +213,6 @@ TASK_REGISTRY = {
    "pile_ubuntu-irc": pile.PileUbuntuIrc,
    "pile_wikipedia": pile.PileWikipedia,
    "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
    # BLiMP
    "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
    "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
...
@@ -297,7 +281,6 @@ TASK_REGISTRY = {
    "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
    "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
    "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
    # Requires manual download of data.
    # "storycloze_2016": storycloze.StoryCloze2016,
    # "storycloze_2018": storycloze.StoryCloze2018,
...
@@ -321,19 +304,43 @@ def get_task_name_from_object(task_object):
    for name, class_ in TASK_REGISTRY.items():
        if class_ is task_object:
            return name

    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else type(task_object).__name__
    )


def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
    task_name_dict = {
        task_name: get_task(task_name)()
        for task_name in task_name_list
        if isinstance(task_name, str)
    }
    task_name_from_object_dict = {
        get_task_name_from_object(task_object): task_object
        for task_object in task_name_list
        if not isinstance(task_object, str)
    }
    assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
    return {**task_name_dict, **task_name_from_object_dict}


def get_task_dict_promptsource(task_name_list: List[str]):
    """Loads a task instance for each prompt written for that task."""
    task_name_dict = {}
    for task_name in task_name_list:
        assert isinstance(task_name, str)
        task_prompts = DatasetTemplates(task_name)
        for prompt_name in task_prompts.all_template_names:
            prompt = task_prompts[prompt_name]
            # NOTE: We choose a sep that can be easily split.
            task_name_dict[f"{task_name}+{prompt_name}"] = get_task(task_name)(
                prompt=prompt
            )
    return task_name_dict
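A small usage sketch of the new prompt-per-task loading path. The task name "coqa" is an assumption for illustration and only works if promptsource ships templates under that dataset name; the point is the key format, one task instance per template.

```python
from lm_eval import tasks

task_dict = tasks.get_task_dict_promptsource(["coqa"])  # hypothetical task list

# One entry per prompt template, keyed "<task_name>+<prompt_name>" so the
# evaluator can later recover both halves with task_name.split("+").
for key in task_dict:
    print(key)
```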
lm_eval/tasks/coqa.py  (view file @ 88745155)

...
@@ -51,44 +51,22 @@ class CoQA(Task):
    def test_docs(self):
        pass

    def doc_to_text(self, doc):
        # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
        # and a question qi, the task is to predict the answer ai
        doc_text = doc["story"] + "\n\n"
        for (q, a) in zip_longest(
            doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
        ):  # omit target answer ai
            question = f"Q: {q}\n\n"
            answer = f"A: {a}\n\n" if a is not None else "A:"
            doc_text += question + answer
        return doc_text

    @classmethod
    def get_answers(cls, doc, turn_id):
        # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
        answers = []
        answer_forturn = doc["answers"]["input_text"][turn_id - 1]
        answers.append(answer_forturn)

        additional_answers = doc.get("additional_answers")
        if additional_answers:
            for key in additional_answers:
                additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
                if additional_answer_for_turn.lower() not in map(str.lower, answers):
                    answers.append(additional_answer_for_turn)
        return answers

    @classmethod
    def get_answer_choice(self, raw_text):
        # Function maps answers to CoQA answer categories
        # ~ 1/5 of the CoQA answers are Yes/No
        # ~ 2/3 of the CoQA answers are span-based
        # (answers overlap with the passage ignoring punctuation and case mismatch)
        if raw_text == "unknown":
            return '0'
        if squad_metrics.normalize_answer(raw_text) == "yes":
            return '1'
        if squad_metrics.normalize_answer(raw_text) == "no":
            return '2'
        return '3'  # Not a yes/no question

    # @classmethod
    # def get_answers(cls, doc, turn_id):
    #     # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
    #     answers = []
    #     answer_forturn = doc["answers"]["input_text"][turn_id - 1]
    #     answers.append(answer_forturn)
    #     additional_answers = doc.get("additional_answers")
    #     if additional_answers:
    #         for key in additional_answers:
    #             additional_answer_for_turn = additional_answers[key]["input_text"][
    #                 turn_id - 1
    #             ]
    #             if additional_answer_for_turn.lower() not in map(str.lower, answers):
    #                 answers.append(additional_answer_for_turn)
    #     return answers

    @staticmethod
    def compute_scores(gold_list, pred):
...
@@ -98,40 +76,38 @@ class CoQA(Task):
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                gold_answers = gold_list[0:i] + gold_list[i + 1 :]
                # predictions compared against (n) golds and take maximum
                em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
        else:
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

        return {
            "em": em_sum / max(1, len(gold_list)),
            "f1": f1_sum / max(1, len(gold_list)),
        }

    def doc_to_target(self, doc, turnid=None):
        # Default to prediction of last turn.
        if turnid is None:
            turnid = len(doc["questions"]["input_text"])
        raw_text = doc["answers"]["input_text"][turnid - 1]
        return " " + raw_text

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        cont_request = rf.greedy_until(ctx, ["\nQ:"])
        return cont_request

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
...
@@ -139,15 +115,18 @@ class CoQA(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
        target = self.doc_to_target(doc).strip()
        pred = results[0].strip().split("\n")[0]

        # turn_id = len(doc["questions"]["input_text"])
        # gold_list = self.get_answers(doc, turn_id)
        # TODO: Add HF metrics mapped from promptsource metadata.
        scores = self.compute_scores([target], pred)

        return {
            "f1": scores["f1"],
            "em": scores["em"],
        }

    def higher_is_better(self):
...
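To see what process_results now feeds into compute_scores, here is a small sketch calling the staticmethod directly; the answer strings are made up, and the second case is only meant to show that exact match can drop to 0 while F1 still credits token overlap.

```python
from lm_eval.tasks.coqa import CoQA

gold = ["a red apple"]

print(CoQA.compute_scores(gold, "a red apple"))   # perfect match: em 1.0, f1 1.0

scores = CoQA.compute_scores(gold, "red apples")  # partial overlap
print(scores["em"], round(scores["f1"], 2))       # em 0.0, f1 somewhere in (0, 1)
```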
lm_eval/tasks/drop.py  (view file @ 88745155)

...
@@ -70,21 +70,26 @@ class DROP(Task):
    @classmethod
    def get_answers(cls, qa):
        def _flatten_validated_answers(validated_answers):
            """Flattens a dict of lists of validated answers.
            {"number": ['1', '8'], ...}
            -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
            """
            vas = []
            for i in range(len(validated_answers["number"])):
                vas.append(
                    {
                        "number": validated_answers["number"][i],
                        "date": validated_answers["date"][i],
                        "spans": validated_answers["spans"][i],
                    }
                )
            return vas

        answers = []
        answers_set = set()
        candidates = [qa["answer"]] + _flatten_validated_answers(
            qa["validated_answers"]
        )
        for candidate in candidates:
            answer = cls.parse_answer(candidate)
            if answer in answers_set:
...
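The flattening helper above turns DROP's column-oriented validated answers into one dict per annotator. An illustrative input/output, with made-up values:

```python
# Hypothetical DROP-style validated answers (column-oriented, as in the dataset).
validated_answers = {
    "number": ["1", "8"],
    "date": [{"day": "", "month": "", "year": ""}, {"day": "", "month": "", "year": ""}],
    "spans": [[], ["eight"]],
}

# _flatten_validated_answers(validated_answers) would produce one dict per index:
# [{"number": "1", "date": {...}, "spans": []},
#  {"number": "8", "date": {...}, "spans": ["eight"]}]
```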
@@ -100,15 +105,17 @@ class DROP(Task):
            return (str(answer["number"]),)
        if answer["spans"] != []:
            return tuple(answer["spans"])
        return (
            " ".join(
                [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
            ).strip(),
        )

    def doc_to_text(self, doc):
        return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"

    # def doc_to_text(self, doc):
    #     return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"

    def doc_to_target(self, doc):
        return " " + ", ".join(doc["answers"][0])

    # def doc_to_target(self, doc):
    #     return " " + ", ".join(doc["answers"][0])

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
...
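For reference, here is how parse_answer (shown above) collapses the three DROP answer kinds into tuples; the field values are hypothetical but follow the structure used above, and the expected results are written as comments rather than asserted.

```python
number_ans = {"number": "3", "spans": [], "date": {"day": "", "month": "", "year": ""}}
span_ans = {"number": "", "spans": ["Carolina Panthers"], "date": {"day": "", "month": "", "year": ""}}
date_ans = {"number": "", "spans": [], "date": {"day": "7", "month": "February", "year": "2016"}}

# DROP.parse_answer(number_ans) -> ("3",)
# DROP.parse_answer(span_ans)   -> ("Carolina Panthers",)
# DROP.parse_answer(date_ans)   -> ("7 February 2016",)
```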
@@ -134,7 +141,13 @@ class DROP(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
        pred = results[0].strip()
        target = self.doc_to_target(doc).strip()
        preds = [pred]
        golds = [target]

        max_em = 0
        max_f1 = 0
        for gold_answer in golds:
...
@@ -142,10 +155,7 @@ class DROP(Task):
            if gold_answer[0].strip():
                max_em = max(max_em, exact_match)
                max_f1 = max(max_f1, f1_score)
        return {"em": max_em, "f1": max_f1}

    def get_metrics(self, predicted, gold):
        """
...
@@ -158,7 +168,9 @@ class DROP(Task):
        predicted_bags = self._answer_to_bags(predicted)
        gold_bags = self._answer_to_bags(gold)

        if set(predicted_bags[0]) == set(gold_bags[0]) and len(
            predicted_bags[0]
        ) == len(gold_bags[0]):
            exact_match = 1.0
        else:
            exact_match = 0.0
...
@@ -190,7 +202,9 @@ class DROP(Task):
        for gold_index, gold_item in enumerate(gold):
            for pred_index, pred_item in enumerate(predicted):
                if self._match_numbers_if_present(gold_item, pred_item):
                    scores[gold_index, pred_index] = self._compute_f1(
                        pred_item, gold_item
                    )
        row_ind, col_ind = linear_sum_assignment(-scores)

        max_scores = np.zeros([max(len(gold), len(predicted))])
...
@@ -256,7 +270,11 @@ class DROP(Task):
    def _normalize(self, answer):
        tokens = [
            self._white_space_fix(
                self._remove_articles(self._fix_number(self._remove_punc(token.lower())))
            )
            for token in self._tokenize(answer)
        ]
        tokens = [token for token in tokens if token.strip()]
...
@@ -269,10 +287,7 @@ class DROP(Task):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"em": mean, "f1": mean}

    def higher_is_better(self):
        """
...
@@ -280,7 +295,4 @@ class DROP(Task):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"em": True, "f1": True}
lm_eval/tasks/race.py  (view file @ 88745155)

...
@@ -40,7 +40,7 @@ class RACE(Task):
    DATASET_NAME = "high"

    cache = {}
    letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}

    def has_training_docs(self):
        return True
...
@@ -59,17 +59,27 @@ class RACE(Task):
        # is shown that one document is made per passage.

        r = collections.defaultdict(list)
        for item in datasets.load_dataset(
            path=self.DATASET_PATH, name=self.DATASET_NAME
        )[set]:
            r[item["article"]].append(item)

        res = list(
            r.values()
            >> each(
                lambda x: {
                    "article": x[0]["article"],
                    "problems": x
                    >> each(
                        lambda y: {
                            "question": y["question"],
                            "answer": y["answer"],
                            "options": y["options"],
                        }
                    ),
                }
            )
        )

        self.cache[set] = res
        return res
...
@@ -85,49 +95,48 @@ class RACE(Task):
    @classmethod
    def get_answer_option(cls, problem):
        answer = cls.letter_to_num[problem["answer"]]
        return problem["options"][answer]

    @classmethod
    def last_problem(cls, doc):
        return doc["problems"][-1]

    def doc_to_text(self, doc):
        text = 'Article: ' + doc['article'] + '\n\n'
        for problem in doc['problems'][:-1]:
            if problem['question'][-6:] == ' _ .':
                text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
            else:
                question = 'Question: ' + problem['question'] + '\n'
                answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
                text += question + answer
        text += self.last_problem(doc)['question']
        return text

    def doc_to_target(self, doc):
        return " " + self.get_answer_option(self.last_problem(doc))

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        problem = self.last_problem(doc)
        ll_choices = [
            rf.loglikelihood(ctx, " " + problem['options'][i])[0] for i in range(4)
        ]
        return ll_choices

    # def doc_to_text(self, doc):
    #     text = 'Article: ' + doc['article'] + '\n\n'
    #     for problem in doc['problems'][:-1]:
    #         if problem['question'][-6:] == ' _ .':
    #             text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
    #         else:
    #             question = 'Question: ' + problem['question'] + '\n'
    #             answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
    #             text += question + answer
    #     text += self.last_problem(doc)['question']
    #     return text

    # def doc_to_target(self, doc):
    #     return " " + self.get_answer_option(self.last_problem(doc))

    # def construct_requests(self, doc, ctx):
    #     """Uses RequestFactory to construct Requests and returns an iterable of
    #     Requests which will be sent to the LM.
    #     :param doc:
    #         The document as returned from training_docs, validation_docs, or test_docs.
    #     :param ctx: str
    #         The context string, generated by fewshot_context. This includes the natural
    #         language description, as well as the few shot examples, and the question
    #         part of the document for `doc`.
    #     """
    #     problem = self.last_problem(doc)
    #     ll_choices = [
    #         rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
    #     ]
    #     return ll_choices

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
...
@@ -135,28 +144,24 @@ class RACE(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
        gold = self.letter_to_num[self.last_problem(doc)['answer']]
        # gold = self.letter_to_num[self.doc_to_target(doc)]
        # gold = self.letter_to_num[self.last_problem(doc)["answer"]]
        pred = np.argmax(results)
        return {"acc": int(pred == gold)}

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"acc": mean}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"acc": True}
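A toy illustration (made-up numbers) of how RACE scores a single document: the four option loglikelihoods are argmax'd and compared against the gold option index from letter_to_num.

```python
import numpy as np

letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}

results = [-4.2, -1.3, -6.0, -5.1]  # hypothetical loglikelihoods for options A-D
gold = letter_to_num["B"]

pred = np.argmax(results)
print({"acc": int(pred == gold)})   # {'acc': 1}
```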