gaoqiong / lm-evaluation-harness · Commits

Commit 88745155
authored Apr 25, 2022 by cjlovering
parent 6caa0afd

    Initial integration

Showing 6 changed files with 480 additions and 318 deletions:

    lm_eval/base.py             +220   -96
    lm_eval/evaluator.py         +70   -35
    lm_eval/tasks/__init__.py    +42   -35
    lm_eval/tasks/coqa.py        +39   -60
    lm_eval/tasks/drop.py        +42   -30
    lm_eval/tasks/race.py        +67   -62
lm_eval/base.py

 import abc
 from typing import Iterable

 import numpy as np
 import random
 import re

@@ -118,7 +119,6 @@ class LM(abc.ABC):
 class BaseLM(LM):
     @property
     @abstractmethod
     def eot_token_id(self):

@@ -145,13 +145,16 @@ class BaseLM(LM):
         pass

     @abstractmethod
     def tok_encode(self, string: str):
         pass

     @abstractmethod
     def tok_decode(self, tokens: Iterable[int]):
         pass

     @abstractmethod
     def _model_generate(self, context, max_length, eos_token_id):
         pass

     @abstractmethod
     def _model_call(self, inps):

@@ -187,19 +190,26 @@ class BaseLM(LM):
         # TODO: automatic batch size detection for vectorization
         loglikelihoods = []
-        for string, in tqdm(requests):
+        for (string,) in tqdm(requests):
             rolling_token_windows = list(
                 map(
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
                         token_list=self.tok_encode(string),
                         prefix_token=self.eot_token_id,
                         max_seq_len=self.max_length,
                         context_len=1,
-                    )))
+                    ),
+                )
+            )

             rolling_token_windows = [(None,) + x for x in rolling_token_windows]

             # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
             # that
             string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)

             # discard is_greedy
             string_nll = [x[0] for x in string_nll]

@@ -226,7 +236,9 @@ class BaseLM(LM):
         # TODO: automatic (variable) batch size detection for vectorization
         reord = utils.Reorderer(requests, _collate)
         for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
             inps = []
             cont_toks_list = []
             inplens = []

@@ -252,44 +264,60 @@ class BaseLM(LM):
                 # when too long to fit in context, truncate from the left
                 inp = torch.tensor(
-                    (context_enc + continuation_enc)[-(self.max_length + 1):][:-1],
-                    dtype=torch.long
+                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
+                    dtype=torch.long,
                 ).to(self.device)
-                inplen, = inp.shape
+                (inplen,) = inp.shape

                 cont = continuation_enc

                 # since in _collate we make sure length is descending, the longest is always the first one.
-                padding_length = padding_length if padding_length is not None else inplen
+                padding_length = (
+                    padding_length if padding_length is not None else inplen
+                )

                 # pad length from seq to padding_length
-                inp = torch.cat([
-                    inp,  # [seq]
-                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device)  # [padding_length - seq]
-                ], dim=0)
+                inp = torch.cat(
+                    [
+                        inp,  # [seq]
+                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
+                            inp.device
+                        ),  # [padding_length - seq]
+                    ],
+                    dim=0,
+                )

                 inps.append(inp.unsqueeze(0))  # [1, padding_length]
                 cont_toks_list.append(cont)
                 inplens.append(inplen)

             batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length
             multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu()  # [batch, padding_length, vocab]

-            for (cache_key, _, _), logits, inp, inplen, cont_toks \
-                    in zip(chunk, multi_logits, inps, inplens, cont_toks_list):
+            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
+                chunk, multi_logits, inps, inplens, cont_toks_list
+            ):

                 # Slice to original seq length
                 contlen = len(cont_toks)
                 logits = logits[inplen - contlen : inplen].unsqueeze(0)  # [1, seq, vocab]

                 # Check if per-token argmax is exactly equal to continuation
                 greedy_tokens = logits.argmax(dim=-1)
                 cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
                 max_equal = (greedy_tokens == cont_toks).all()

                 # Obtain log-probs at the corresponding continuation token indices
                 # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                 logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                 # Answer: (log prob, is-exact-match)
                 answer = (float(logits.sum()), bool(max_equal))
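
Note: the gather-and-sum step above is the core of per-continuation scoring. A self-contained sketch of the same idea on toy tensors (not part of this commit; shapes and values are illustrative):

    import torch
    import torch.nn.functional as F

    # Toy example: vocab of 5, a 3-token continuation, batch of 1.
    raw_scores = torch.randn(1, 3, 5)           # [1, seq, vocab] model outputs
    logits = F.log_softmax(raw_scores, dim=-1)  # per-token log-probabilities
    cont_toks = torch.tensor([[2, 0, 4]])       # [1, seq] continuation token ids

    # Greedy check: does the argmax at every position reproduce the continuation?
    max_equal = (logits.argmax(dim=-1) == cont_toks).all()

    # Pick out the log-prob assigned to each continuation token, then sum.
    token_logprobs = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
    answer = (float(token_logprobs.sum()), bool(max_equal))
    print(answer)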
@@ -319,13 +347,17 @@ class BaseLM(LM):
             if isinstance(until, str):
                 until = [until]

-            primary_until, = self.tok_encode(until[0])
+            (primary_until,) = self.tok_encode(until[0])

-            context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)
+            context_enc = torch.tensor(
+                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
+            ).to(self.device)

             cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)

             s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

             for term in until:
                 s = s.split(term)[0]

@@ -383,7 +415,7 @@ class Task(abc.ABC):
         self._fewshot_docs = None

     def download(self, data_dir=None, cache_dir=None, download_mode=None):
-        """
-        Downloads and returns the task dataset.
+        """Downloads and returns the task dataset.
         Override this method to download the dataset from a custom API.

         :param data_dir: str

@@ -412,7 +444,7 @@ class Task(abc.ABC):
             name=self.DATASET_NAME,
             data_dir=data_dir,
             cache_dir=cache_dir,
-            download_mode=download_mode
+            download_mode=download_mode,
         )

     @abstractmethod

@@ -478,7 +510,7 @@ class Task(abc.ABC):
     @abstractmethod
     def construct_requests(self, doc, ctx):
-        """
-        Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:

@@ -523,15 +555,19 @@ class Task(abc.ABC):
     def fewshot_description(self):
         import warnings
         warnings.warn(
             "`fewshot_description` will be removed in futures versions. Pass "
             "any custom descriptions to the `evaluate` function instead.",
-            DeprecationWarning)
+            DeprecationWarning,
+        )
         return ""

     @utils.positional_deprecated
-    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
-        """ Returns a fewshot context string that is made up of a prepended description
+    def fewshot_context(
+        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
+    ):
+        """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.

         :param doc: str

@@ -548,7 +584,9 @@ class Task(abc.ABC):
         :returns: str
             The fewshot context.
         """
-        assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"
+        assert (
+            rnd is not None
+        ), "A `random.Random` generator argument must be provided to `rnd`"
         assert not provide_description, (
             "The `provide_description` arg will be removed in future versions. To prepend "
             "a custom description to the context, supply the corresponding string via the "

@@ -556,7 +594,9 @@ class Task(abc.ABC):
         )

         if provide_description is not None:
             # nudge people to not specify it at all
             print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

         description = description + "\n\n" if description else ""

@@ -569,7 +609,9 @@ class Task(abc.ABC):
             else:
                 if self._fewshot_docs is None:
                     self._fewshot_docs = list(
                         self.validation_docs() if self.has_validation_docs() else self.test_docs()
                     )

                 fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
@@ -577,23 +619,90 @@ class Task(abc.ABC):
             # get rid of the doc that's the one we're evaluating, if it's in the fewshot
             fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

-            labeled_examples = "\n\n".join(
-                [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]
-            ) + "\n\n"
+            labeled_examples = (
+                "\n\n".join(
+                    [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex]
+                )
+                + "\n\n"
+            )

         example = self.doc_to_text(doc)
         return description + labeled_examples + example


+class PromptSourceTask(Task):
+    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
+        super().__init__(data_dir, cache_dir, download_mode)
+        self.prompt = prompt
+
+    def doc_to_target(self, doc):
+        _, target = prompt.apply(doc)
+        return f"{target}"
+
+    def doc_to_text(self, doc):
+        text, _ = prompt.apply(doc)
+        return text
+
+    def construct_requests(self, doc, ctx):
+        """Uses RequestFactory to construct Requests and returns an iterable of
+        Requests which will be sent to the LM.
+
+        :param doc:
+            The document as returned from training_docs, validation_docs, or test_docs.
+        :param ctx: str
+            The context string, generated by fewshot_context. This includes the natural
+            language description, as well as the few shot examples, and the question
+            part of the document for `doc`.
+        """
+        _requests = []
+        if self.prompt.metadata.choices_in_prompt:
+            for answer_choice in prompt.get_fixed_answer_choices_list():
+                ll_answer_choice, _ = rf.loglikelihood(ctx, f"{answer_choice}")
+                _requests.append(ll_answer_choice)
+        else:
+            # TODO(Albert): What is the stop symbol? Is it model specific?
+            ll_greedy, _ = rf.greedy_until(ctx, ["\nQ:"])
+            _requests.append(ll_greedy)
+
+        return _requests
+
+    def process_results(self, doc, results):
+        """Take a single document and the LM results and evaluates, returning a
+        dict where keys are the names of submetrics and values are the values of
+        the metric for that one document
+
+        :param doc:
+            The document as returned from training_docs, validation_docs, or test_docs.
+        :param results:
+            The results of the requests created in construct_requests.
+        """
+        raise NotImplementedError(
+            "Implement process results using the `prompt.metadata.metrics`. See below."
+        )
+        if self.prompt.metadata.choices_in_prompt:
+            for result, answer_choice in zip(prompt.get_fixed_answer_choices_list(), results):
+                pass
+        else:
+            continuation = results
+
+        # Map metric name to HF metric.
+        # TODO(Albert): What is Other?
+        metric_names = prompt.metadata.metrics
+
+
 class MultipleChoiceTask(Task):

     def doc_to_target(self, doc):
-        return " " + doc['choices'][doc['gold']]
+        return " " + doc["choices"][doc["gold"]]

     def construct_requests(self, doc, ctx):
         lls = [
             rf.loglikelihood(ctx, " {}".format(choice))[0]
-            for choice in doc['choices']
+            for choice in doc["choices"]
         ]

         return lls
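
Note: the new PromptSourceTask leans on promptsource's Template.apply, which renders a dataset example into a (text, target) pair. A minimal sketch of that flow, not part of this commit; the dataset, prompt name, and example fields below are placeholders:

    from promptsource.templates import DatasetTemplates

    # Load the prompts written for a dataset (names here are illustrative).
    prompts = DatasetTemplates("super_glue", "boolq")
    prompt = prompts[prompts.all_template_names[0]]

    # A raw example as it would come out of training_docs()/validation_docs().
    doc = {"passage": "The sky is blue.", "question": "Is the sky blue?", "label": 1}

    # Template.apply returns the rendered input text and the target string,
    # which is what doc_to_text / doc_to_target unpack above.
    text, target = prompt.apply(doc)
    print(text)
    print(target)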
@@ -601,9 +710,9 @@ class MultipleChoiceTask(Task):
     def process_results(self, doc, results):
         gold = doc["gold"]

-        acc = 1. if np.argmax(results) == gold else 0.
+        acc = 1.0 if np.argmax(results) == gold else 0.0
         completion_len = np.array([float(len(i)) for i in doc["choices"]])
-        acc_norm = 1. if np.argmax(results / completion_len) == gold else 0.
+        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

         return {
             "acc": acc,

@@ -624,7 +733,6 @@ class MultipleChoiceTask(Task):
 class PerplexityTask(Task, abc.ABC):
     def has_training_docs(self):
         return False

@@ -632,9 +740,15 @@ class PerplexityTask(Task, abc.ABC):
         assert k == 0
         return []

-    def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None):
-        assert num_fewshot == 0, "The number of fewshot examples must be 0 for perplexity tasks."
-        assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`."
+    def fewshot_context(
+        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
+    ):
+        assert (
+            num_fewshot == 0
+        ), "The number of fewshot examples must be 0 for perplexity tasks."
+        assert (
+            rnd is not None
+        ), "A `random.Random` generator argument must be provided to `rnd`."
         assert not provide_description, (
             "The `provide_description` arg will be removed in future versions. To prepend "
             "a custom description to the context, supply the corresponding string via the "

@@ -642,7 +756,9 @@ class PerplexityTask(Task, abc.ABC):
         )

         if provide_description is not None:
             # nudge people to not specify it at all
             print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

         return ""

@@ -665,7 +781,7 @@ class PerplexityTask(Task, abc.ABC):
         return req

     def process_results(self, doc, results):
-        loglikelihood, = results
+        (loglikelihood,) = results
         words = self.count_words(doc)
         bytes_ = self.count_bytes(doc)
         return {

@@ -687,13 +803,13 @@ class PerplexityTask(Task, abc.ABC):
     @classmethod
     def count_words(cls, doc):
-        """
-        Downstream tasks with custom word boundaries should override this!
-        """
+        """Downstream tasks with custom word boundaries should override this!"""
         return len(re.split(r"\s+", doc))


 def hash_args(attr, args):
     dat = json.dumps([attr] + list(args))
-    return hashlib.sha256(dat.encode('utf-8')).hexdigest()
+    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


 class CacheHook:

@@ -764,6 +880,7 @@ class CachingLM:
                 self.dbdict.commit()

             return res

         return fn

     def get_cache_hook(self):

@@ -771,16 +888,18 @@ class CachingLM:
 REQUEST_RETURN_LENGTHS = {
-    'loglikelihood': 2,
-    'greedy_until': None,
-    'loglikelihood_rolling': None,
+    "loglikelihood": 2,
+    "greedy_until": None,
+    "loglikelihood_rolling": None,
 }


 class Request:
     def __init__(self, request_type, args, index=None):
         if request_type not in REQUEST_RETURN_LENGTHS.keys():
-            raise NotImplementedError('The request type {} is not implemented!'.format(request_type))
+            raise NotImplementedError(
+                "The request type {} is not implemented!".format(request_type)
+            )

         self.request_type = request_type
         self.args = args

@@ -788,17 +907,21 @@ class Request:
     def __iter__(self):
         if REQUEST_RETURN_LENGTHS[self.request_type] is None:
-            raise IndexError('This request type does not return multiple arguments!')
+            raise IndexError("This request type does not return multiple arguments!")
         for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
             yield Request(self.request_type, self.args, i)

     def __getitem__(self, i):
         if REQUEST_RETURN_LENGTHS[self.request_type] is None:
-            raise IndexError('This request type does not return multiple arguments!')
+            raise IndexError("This request type does not return multiple arguments!")
         return Request(self.request_type, self.args, i)

     def __eq__(self, other):
-        return self.request_type == other.request_type and self.args == other.args and self.index == other.index
+        return (
+            self.request_type == other.request_type
+            and self.args == other.args
+            and self.index == other.index
+        )

     def __repr__(self):
         return f"Req_{self.request_type}{self.args}[{self.index}]\n"

@@ -808,6 +931,7 @@ class RequestFactory:
     def __getattr__(self, attr):
         def fn(*args):
             return Request(attr, args)

         return fn
lm_eval/evaluator.py

@@ -6,15 +6,27 @@ import lm_eval.metrics
 import lm_eval.models
 import lm_eval.tasks
 import lm_eval.base
+import promptsource
 import numpy as np
+from promptsource.templates import DatasetTemplates

 from lm_eval.utils import positional_deprecated, run_task_tests


 @positional_deprecated
-def simple_evaluate(model, model_args=None, tasks=[],
-                    num_fewshot=0, batch_size=None, device=None,
-                    no_cache=False, limit=None, bootstrap_iters=100000,
-                    description_dict=None, check_integrity=False):
+def simple_evaluate(
+    model,
+    model_args=None,
+    tasks=[],
+    num_fewshot=0,
+    batch_size=None,
+    device=None,
+    no_cache=False,
+    limit=None,
+    bootstrap_iters=100000,
+    description_dict=None,
+    check_integrity=False,
+):
     """Instantiate and evaluate a model on a list of tasks.

     :param model: Union[str, LM]

@@ -49,20 +61,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
     assert tasks != [], "No tasks specified"

     if isinstance(model, str):
         if model_args is None:
             model_args = ""
-        lm = lm_eval.models.get_model(model).create_from_arg_string(model_args, {
-            'batch_size': batch_size, 'device': device
-        })
+        lm = lm_eval.models.get_model(model).create_from_arg_string(
+            model_args, {"batch_size": batch_size, "device": device}
+        )
     else:
         assert isinstance(model, lm_eval.base.LM)
         lm = model

     if not no_cache:
         lm = lm_eval.base.CachingLM(
-            lm, 'lm_cache/' + model + '_' + model_args.replace('=', '-').replace(',', '_').replace('/', '-') + '.db'
+            lm,
+            "lm_cache/"
+            + model
+            + "_"
+            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
+            + ".db",
         )

-    task_dict = lm_eval.tasks.get_task_dict(tasks)
+    task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks)

     if check_integrity:
         run_task_tests(task_list=tasks)

@@ -72,7 +90,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
         task_dict=task_dict,
         num_fewshot=num_fewshot,
         limit=limit,
-        description_dict=description_dict
+        description_dict=description_dict,
     )

     # add info about the model and few shot config

@@ -85,14 +103,22 @@ def simple_evaluate(model, model_args=None, tasks=[],
         "no_cache": no_cache,
         "limit": limit,
         "bootstrap_iters": bootstrap_iters,
-        "description_dict": description_dict
+        "description_dict": description_dict,
     }

     return results


 @positional_deprecated
-def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None, bootstrap_iters=100000, description_dict=None):
+def evaluate(
+    lm,
+    task_dict,
+    provide_description=None,
+    num_fewshot=0,
+    limit=None,
+    bootstrap_iters=100000,
+    description_dict=None,
+):
     """Instantiate and evaluate a model on a list of tasks.

     :param lm: obj

@@ -118,12 +144,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
     assert not provide_description  # not implemented.
     if provide_description is not None:
         # nudge people to not specify it at all
         print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict")

     task_dict_items = [
         (name, task)
         for name, task in task_dict.items()
         if (task.has_validation_docs() or task.has_test_docs())
     ]

     results = collections.defaultdict(dict)

@@ -158,15 +186,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         rnd.seed(42)
         rnd.shuffle(task_docs)

-        description = description_dict[task_name] if description_dict and task_name in description_dict else ""
+        description = (
+            description_dict[task_name]
+            if description_dict and task_name in description_dict
+            else ""
+        )

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
             docs[(task_name, doc_id)] = doc
             ctx = task.fewshot_context(
                 doc=doc,
                 num_fewshot=num_fewshot, rnd=rnd, description=description
             )
             reqs = task.construct_requests(doc, ctx)
             if not isinstance(reqs, (list, tuple)):

@@ -189,7 +218,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         print("Running", reqtype, "requests")
         resps = getattr(lm, reqtype)([req.args for req in reqs])
-        resps = [x if req.index is None else x[req.index] for x, req in zip(resps, reqs)]
+        resps = [
+            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
+        ]

         for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
             process_res_queue[(task_name, doc_id)].append((i, resp))

@@ -208,24 +239,28 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
         for metric, value in metrics.items():
             vals[(task_name, metric)].append(value)

+    task_name, prompt_name = task_name.split("+")
+    results[task_name]["task_name"] = task_name
+    results[task_name]["prompt_name"] = prompt_name
+
     # aggregate results
     for (task_name, metric), items in vals.items():
         task = task_dict[task_name]
         results[task_name][metric] = task.aggregation()[metric](items)

         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
         stderr = lm_eval.metrics.stderr_for_metric(
             metric=task.aggregation()[metric],
-            bootstrap_iters=min(bootstrap_iters, 1000) if metric in ["bleu", "chrf", "ter"] else bootstrap_iters,
+            bootstrap_iters=min(bootstrap_iters, 1000)
+            if metric in ["bleu", "chrf", "ter"]
+            else bootstrap_iters,
         )

         if stderr is not None:
             results[task_name][metric + "_stderr"] = stderr(items)

     return {"results": dict(results), "versions": dict(versions)}


 def make_table(result_dict):

@@ -247,9 +282,9 @@ def make_table(result_dict):
         if m + "_stderr" in dic:
             se = dic[m + "_stderr"]
-            values.append([k, version, m, '%.4f' % v, '±', '%.4f' % se])
+            values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
         else:
-            values.append([k, version, m, '%.4f' % v, '', ''])
+            values.append([k, version, m, "%.4f" % v, "", ""])
         k = ""
         version = ""
     md_writer.value_matrix = values
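
Note: with get_task_dict_promptsource, every entry in task_dict is keyed as "{task}+{prompt}", and evaluate later recovers the two halves with split("+"). A small sketch of that naming convention (not part of the commit; the names below are made up):

    # Keys produced by get_task_dict_promptsource look like "<dataset>+<prompt name>".
    key = "coqa+exact_answer"   # illustrative key, not a real prompt name

    task_name, prompt_name = key.split("+")
    print(task_name)    # "coqa"
    print(prompt_name)  # "exact_answer"

    # A prompt name containing "+" would break this split; the code relies on
    # the separator being chosen so the key "can be easily split".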
lm_eval/tasks/__init__.py

+from promptsource.templates import DatasetTemplates
 from pprint import pprint
 from typing import List, Union

@@ -58,8 +60,8 @@ from . import storycloze
 # 6 total
 gpt3_translation_benchmarks = {
-    "wmt14": ['en-fr', 'fr-en'],  # French
-    "wmt16": ['en-ro', 'ro-en', 'de-en', 'en-de'],  # German, Romanian
+    "wmt14": ["en-fr", "fr-en"],  # French
+    "wmt16": ["en-ro", "ro-en", "de-en", "en-de"],  # German, Romanian
 }

@@ -67,7 +69,7 @@ gpt3_translation_benchmarks = {
 selected_translation_benchmarks = {
     **gpt3_translation_benchmarks,
     "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
-    "iwslt17": ['en-ar', 'ar-en']  # Arabic
+    "iwslt17": ["en-ar", "ar-en"],  # Arabic
 }

 # 319 total

@@ -91,7 +93,7 @@ TASK_REGISTRY = {
     "rte": glue.RTE,
     "qnli": glue.QNLI,
     "qqp": glue.QQP,
-    #"stsb": glue.STSB, # not implemented yet
+    # "stsb": glue.STSB, # not implemented yet
     "sst": glue.SST,
     "wnli": glue.WNLI,
     # SuperGLUE

@@ -102,34 +104,26 @@ TASK_REGISTRY = {
     "record": superglue.ReCoRD,
     "wic": superglue.WordsInContext,
     "wsc": superglue.SGWinogradSchemaChallenge,
     # Order by benchmark/genre?
     "coqa": coqa.CoQA,
     "drop": drop.DROP,
     "lambada": lambada.LAMBADA,
     "lambada_cloze": lambada_cloze.LAMBADA_cloze,
     # multilingual lambada
     **lambada_multilingual.construct_tasks(),
     "wikitext": wikitext.WikiText,
     # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
     # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
     "piqa": piqa.PiQA,
     "prost": prost.PROST,
     "mc_taco": mc_taco.MCTACO,
     # Science related
     "pubmedqa": pubmedqa.Pubmed_QA,
     "sciq": sciq.SciQ,
     "qasper": qasper.QASPER,
     "qa4mre_2011": qa4mre.QA4MRE_2011,
     "qa4mre_2012": qa4mre.QA4MRE_2012,
     "qa4mre_2013": qa4mre.QA4MRE_2013,
     "triviaqa": triviaqa.TriviaQA,
     "arc_easy": arc.ARCEasy,
     "arc_challenge": arc.ARCChallenge,

@@ -150,21 +144,17 @@ TASK_REGISTRY = {
     "anli_r1": anli.ANLIRound1,
     "anli_r2": anli.ANLIRound2,
     "anli_r3": anli.ANLIRound3,
     "ethics_cm": hendrycks_ethics.EthicsCM,
     "ethics_deontology": hendrycks_ethics.EthicsDeontology,
     "ethics_justice": hendrycks_ethics.EthicsJustice,
     "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
     "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
     "ethics_virtue": hendrycks_ethics.EthicsVirtue,
     "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
     "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
     # dialogue
     "mutual": mutual.MuTual,
     "mutual_plus": mutual.MuTualPlus,
     # math
     "math_algebra": hendrycks_math.MathAlgebra,
     "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,

@@ -175,7 +165,6 @@ TASK_REGISTRY = {
     "math_precalc": hendrycks_math.MathPrecalculus,
     "math_asdiv": asdiv.Asdiv,
     "gsm8k": gsm8k.GradeSchoolMath8K,
     # arithmetic
     "arithmetic_2da": arithmetic.Arithmetic2DPlus,
     "arithmetic_2ds": arithmetic.Arithmetic2DMinus,

@@ -189,22 +178,18 @@ TASK_REGISTRY = {
     "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
     # TODO Perhaps make these groups of tasks
     #   e.g. anli, arithmetic, openai_translations, harness_translations
     # hendrycksTest (57 tasks)
     **hendrycks_test.create_all_tasks(),
     # e.g. wmt14-fr-en
     **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
     # chef's selection, mostly wmt20
     **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
     # Word Scrambling and Manipulation Tasks
     "anagrams1": unscramble.Anagrams1,
     "anagrams2": unscramble.Anagrams2,
     "cycle_letters": unscramble.CycleLetters,
     "random_insertion": unscramble.RandomInsertion,
     "reversed_words": unscramble.ReversedWords,
     # Pile
     "pile_arxiv": pile.PileArxiv,
     "pile_books3": pile.PileBooks3,

@@ -228,7 +213,6 @@ TASK_REGISTRY = {
     "pile_ubuntu-irc": pile.PileUbuntuIrc,
     "pile_wikipedia": pile.PileWikipedia,
     "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
     # BLiMP
     "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
     "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,

@@ -297,7 +281,6 @@ TASK_REGISTRY = {
     "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
     "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
     "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
     # Requires manual download of data.
     # "storycloze_2016": storycloze.StoryCloze2016,
     # "storycloze_2018": storycloze.StoryCloze2018,

@@ -323,17 +306,41 @@ def get_task_name_from_object(task_object):
         return name

     # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
-    return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__
+    return (
+        task_object.EVAL_HARNESS_NAME
+        if hasattr(task_object, "EVAL_HARNESS_NAME")
+        else type(task_object).__name__
+    )


 def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
     task_name_dict = {
         task_name: get_task(task_name)()
         for task_name in task_name_list
         if isinstance(task_name, str)
     }
     task_name_from_object_dict = {
         get_task_name_from_object(task_object): task_object
         for task_object in task_name_list
         if not isinstance(task_object, str)
     }
     assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
     return {**task_name_dict, **task_name_from_object_dict}


+def get_task_dict_promptsource(task_name_list: List[str]):
+    """Loads a task instance for each prompt written for that task."""
+    task_name_dict = {}
+    for task_name in task_name_list:
+        assert isinstance(task_name, str)
+
+        task_prompts = DatasetTemplates(task_name)
+        for prompt_name in task_prompts.all_template_names:
+            prompt = task_prompts[prompt_name]
+            # NOTE: We choose a sep that can be easily split.
+            task_name_dict[f"{task_name}+{prompt_name}"] = get_task(task_name)(prompt=prompt)
+
+    return task_name_dict
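
Note: with this change a single registry name fans out into one task instance per promptsource template. A hedged usage sketch, not part of the commit; the model and task names are illustrative, and the task must both be in TASK_REGISTRY and have promptsource templates (and accept a `prompt=` kwarg) for this to run:

    import lm_eval.evaluator
    import lm_eval.tasks

    # One registry entry, many prompts: each template gets its own task instance.
    task_dict = lm_eval.tasks.get_task_dict_promptsource(["coqa"])
    print(sorted(task_dict))   # e.g. ["coqa+<prompt 1>", "coqa+<prompt 2>", ...]

    # simple_evaluate now routes through the same expansion internally.
    results = lm_eval.evaluator.simple_evaluate(
        model="gpt2",          # illustrative registered model string
        tasks=["coqa"],
        num_fewshot=0,
        limit=8,               # keep the sketch cheap
    )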
lm_eval/tasks/coqa.py

@@ -51,44 +51,22 @@ class CoQA(Task):
     def test_docs(self):
         pass

-    def doc_to_text(self, doc):
-        # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
-        # and a question qi, the task is to predict the answer ai
-        doc_text = doc["story"] + '\n\n'
-        for (q, a) in zip_longest(doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]):  # omit target answer ai
-            question = f"Q: {q}\n\n"
-            answer = f"A: {a}\n\n" if a is not None else "A:"
-            doc_text += question + answer
-        return doc_text
-
-    @classmethod
-    def get_answers(cls, doc, turn_id):
-        # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
-        answers = []
-        answer_forturn = doc["answers"]["input_text"][turn_id - 1]
-        answers.append(answer_forturn)
-
-        additional_answers = doc.get("additional_answers")
-        if additional_answers:
-            for key in additional_answers:
-                additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
-                if additional_answer_for_turn.lower() not in map(str.lower, answers):
-                    answers.append(additional_answer_for_turn)
-        return answers
-
-    @classmethod
-    def get_answer_choice(self, raw_text):
-        # Function maps answers to CoQA answer categories
-        # ~ 1/5 of the CoQA answers are Yes/No
-        # ~ 2/3 of the CoQA answers are span-based
-        # (answers overlap with the passage ignoring punctuation and case mismatch)
-        if raw_text == "unknown":
-            return '0'
-        if squad_metrics.normalize_answer(raw_text) == "yes":
-            return '1'
-        if squad_metrics.normalize_answer(raw_text) == "no":
-            return '2'
-        return '3'  # Not a yes/no question
+    # @classmethod
+    # def get_answers(cls, doc, turn_id):
+    #     # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
+    #     answers = []
+    #     answer_forturn = doc["answers"]["input_text"][turn_id - 1]
+    #     answers.append(answer_forturn)
+
+    #     additional_answers = doc.get("additional_answers")
+    #     if additional_answers:
+    #         for key in additional_answers:
+    #             additional_answer_for_turn = additional_answers[key]["input_text"][
+    #                 turn_id - 1
+    #             ]
+    #             if additional_answer_for_turn.lower() not in map(str.lower, answers):
+    #                 answers.append(additional_answer_for_turn)
+    #     return answers

     @staticmethod
     def compute_scores(gold_list, pred):

@@ -98,25 +76,23 @@ class CoQA(Task):
         em_sum = 0.0
         if len(gold_list) > 1:
             for i in range(len(gold_list)):
                 gold_answers = gold_list[0:i] + gold_list[i + 1:]
                 # predictions compared against (n) golds and take maximum
                 em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
                 f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
         else:
             em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
             f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

-        return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
+        return {
+            "em": em_sum / max(1, len(gold_list)),
+            "f1": f1_sum / max(1, len(gold_list)),
+        }
+
+    def doc_to_target(self, doc, turnid=None):
+        # Default to prediction of last turn.
+        if turnid is None:
+            turnid = len(doc["questions"]["input_text"])
+        raw_text = doc['answers']["input_text"][turnid - 1]
+        return " " + raw_text

     def construct_requests(self, doc, ctx):
-        """
-        Uses RequestFactory to construct Requests and returns an iterable of
+        """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.

         :param doc:

@@ -126,7 +102,7 @@ class CoQA(Task):
         language description, as well as the few shot examples, and the question
         part of the document for `doc`.
         """
-        cont_request = rf.greedy_until(ctx, ['\nQ:'])
+        cont_request = rf.greedy_until(ctx, ["\nQ:"])
         return cont_request

     def process_results(self, doc, results):

@@ -139,15 +115,18 @@ class CoQA(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        turn_id = len(doc["questions"]["input_text"])
-        gold_list = self.get_answers(doc, turn_id)
-        pred = results[0].strip().split('\n')[0]
-
-        scores = self.compute_scores(gold_list, pred)
+        target = self.doc_to_target(doc).strip()
+        pred = results[0].strip().split("\n")[0]
+        # turn_id = len(doc["questions"]["input_text"])
+        # gold_list = self.get_answers(doc, turn_id)
+
+        # TODO: Add HF metrics mapped from promptsource metadata.
+        scores = self.compute_scores([target], pred)

         return {
-            "f1": scores['f1'],
-            "em": scores['em'],
+            "f1": scores["f1"],
+            "em": scores["em"],
         }

     def higher_is_better(self):
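
Note: compute_scores now receives a single gold answer, [target], instead of the list returned by get_answers. The underlying scoring helpers come from transformers (referenced here the same way coqa.py references squad_metrics); a minimal sketch of what they compute on a toy pair, not part of the commit, with illustrative strings:

    from transformers.data.metrics import squad_metrics

    gold = "The Blue House."
    pred_good = "blue house"
    pred_part = "a red house"

    # SQuAD-style normalization lowercases and strips punctuation/articles before comparing.
    print(squad_metrics.compute_exact(gold, pred_good))  # 1
    print(squad_metrics.compute_f1(gold, pred_good))     # 1.0
    print(squad_metrics.compute_exact(gold, pred_part))  # 0
    print(squad_metrics.compute_f1(gold, pred_part))     # 0.5 (one of two tokens overlaps)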
lm_eval/tasks/drop.py

@@ -70,21 +70,26 @@ class DROP(Task):
     @classmethod
     def get_answers(cls, qa):
         def _flatten_validated_answers(validated_answers):
-            """
-            Flattens a dict of lists of validated answers.
+            """Flattens a dict of lists of validated answers.
             {"number": ['1', '8'], ...}
             -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
             """
             vas = []
             for i in range(len(validated_answers["number"])):
-                vas.append({
-                    "number": validated_answers["number"][i],
-                    "date": validated_answers["date"][i],
-                    "spans": validated_answers["spans"][i],
-                })
+                vas.append(
+                    {
+                        "number": validated_answers["number"][i],
+                        "date": validated_answers["date"][i],
+                        "spans": validated_answers["spans"][i],
+                    }
+                )
             return vas

         answers = []
         answers_set = set()
-        candidates = [qa["answer"]] + _flatten_validated_answers(qa["validated_answers"])
+        candidates = [qa["answer"]] + _flatten_validated_answers(
+            qa["validated_answers"]
+        )
         for candidate in candidates:
             answer = cls.parse_answer(candidate)
             if answer in answers_set:

@@ -100,15 +105,17 @@ class DROP(Task):
             return (str(answer["number"]),)
         if answer["spans"] != []:
             return tuple(answer["spans"])
-        return (" ".join([answer["date"]["day"],
-                          answer["date"]["month"],
-                          answer["date"]["year"]]).strip(),)
+        return (
+            " ".join(
+                [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
+            ).strip(),
+        )

-    def doc_to_text(self, doc):
-        return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
+    # def doc_to_text(self, doc):
+    #     return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"

-    def doc_to_target(self, doc):
-        return " " + ", ".join(doc["answers"][0])
+    # def doc_to_target(self, doc):
+    #     return " " + ", ".join(doc["answers"][0])

     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of

@@ -134,7 +141,13 @@ class DROP(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        preds, golds = results, doc["answers"]
+        pred = results[0].strip()
+        target = self.doc_to_target(doc).strip()
+        preds = [pred]
+        golds = [target]
+
         max_em = 0
         max_f1 = 0
         for gold_answer in golds:

@@ -142,10 +155,7 @@ class DROP(Task):
             if gold_answer[0].strip():
                 max_em = max(max_em, exact_match)
                 max_f1 = max(max_f1, f1_score)
         return {"em": max_em, "f1": max_f1}

     def get_metrics(self, predicted, gold):
         """

@@ -158,7 +168,9 @@ class DROP(Task):
         predicted_bags = self._answer_to_bags(predicted)
         gold_bags = self._answer_to_bags(gold)

-        if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
+        if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(
+            gold_bags[0]
+        ):
             exact_match = 1.0
         else:
             exact_match = 0.0

@@ -190,7 +202,9 @@ class DROP(Task):
         for gold_index, gold_item in enumerate(gold):
             for pred_index, pred_item in enumerate(predicted):
                 if self._match_numbers_if_present(gold_item, pred_item):
                     scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item)
         row_ind, col_ind = linear_sum_assignment(-scores)

         max_scores = np.zeros([max(len(gold), len(predicted))])

@@ -256,7 +270,11 @@ class DROP(Task):
     def _normalize(self, answer):
         tokens = [
-            self._white_space_fix(self._remove_articles(self._fix_number(self._remove_punc(token.lower()))))
+            self._white_space_fix(
+                self._remove_articles(self._fix_number(self._remove_punc(token.lower())))
+            )
             for token in self._tokenize(answer)
         ]
         tokens = [token for token in tokens if token.strip()]

@@ -269,10 +287,7 @@ class DROP(Task):
             A dictionary where keys are the names of submetrics and values are
             functions that aggregate a list of metrics
         """
         return {"em": mean, "f1": mean}

     def higher_is_better(self):
         """

@@ -280,7 +295,4 @@ class DROP(Task):
             A dictionary where keys are the names of submetrics and values are
             whether a higher value of the submetric is better
         """
         return {"em": True, "f1": True}
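
Note: the reflowed hunk above shows how parse_answer normalizes each DROP answer record into a tuple, preferring the number field, then spans, then the assembled date. A small sketch of that selection logic on a toy record, not part of the commit; the helper name and record are illustrative, and the guard on the number field is assumed since it falls outside the hunk:

    def parse_answer_sketch(answer):
        # Prefer the numeric answer when present (guard assumed, not shown in the diff).
        if answer["number"] != "":
            return (str(answer["number"]),)
        # Otherwise fall back to the span answers.
        if answer["spans"] != []:
            return tuple(answer["spans"])
        # Finally assemble a date string from day/month/year.
        return (
            " ".join(
                [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
            ).strip(),
        )

    record = {"number": "", "spans": ["Normandy"], "date": {"day": "", "month": "", "year": ""}}
    print(parse_answer_sketch(record))  # ('Normandy',)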
lm_eval/tasks/race.py

@@ -40,7 +40,7 @@ class RACE(Task):
     DATASET_NAME = "high"

     cache = {}
-    letter_to_num = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
+    letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}

     def has_training_docs(self):
         return True

@@ -59,17 +59,27 @@ class RACE(Task):
         # is shown that one document is made per passage.
         r = collections.defaultdict(list)
-        for item in datasets.load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)[set]:
-            r[item['article']].append(item)
-
-        res = list(r.values() >> each(lambda x: {
-            'article': x[0]['article'],
-            'problems': x >> each(lambda y: {
-                'question': y['question'],
-                'answer': y['answer'],
-                'options': y['options'],
-            })
-        }))
+        for item in datasets.load_dataset(
+            path=self.DATASET_PATH, name=self.DATASET_NAME
+        )[set]:
+            r[item["article"]].append(item)
+
+        res = list(
+            r.values()
+            >> each(
+                lambda x: {
+                    "article": x[0]["article"],
+                    "problems": x
+                    >> each(
+                        lambda y: {
+                            "question": y["question"],
+                            "answer": y["answer"],
+                            "options": y["options"],
+                        }
+                    ),
+                }
+            )
+        )

         self.cache[set] = res
         return res

@@ -85,45 +95,44 @@ class RACE(Task):
     @classmethod
     def get_answer_option(cls, problem):
-        answer = cls.letter_to_num[problem['answer']]
-        return problem['options'][answer]
+        answer = cls.letter_to_num[problem["answer"]]
+        return problem["options"][answer]

     @classmethod
     def last_problem(cls, doc):
-        return doc['problems'][-1]
+        return doc["problems"][-1]

-    def doc_to_text(self, doc):
-        text = 'Article: ' + doc['article'] + '\n\n'
-        for problem in doc['problems'][:-1]:
-            if problem['question'][-6:] == ' _ .':
-                text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
-            else:
-                question = 'Question: ' + problem['question'] + '\n'
-                answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
-                text += question + answer
-        text += self.last_problem(doc)['question']
-        return text
-
-    def doc_to_target(self, doc):
-        return " " + self.get_answer_option(self.last_problem(doc))
-
-    def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        problem = self.last_problem(doc)
-        ll_choices = [
-            rf.loglikelihood(ctx, " " + problem['options'][i])[0]
-            for i in range(4)
-        ]
-        return ll_choices
+    # def doc_to_text(self, doc):
+    #     text = 'Article: ' + doc['article'] + '\n\n'
+    #     for problem in doc['problems'][:-1]:
+    #         if problem['question'][-6:] == ' _ .':
+    #             text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
+    #         else:
+    #             question = 'Question: ' + problem['question'] + '\n'
+    #             answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
+    #             text += question + answer
+    #     text += self.last_problem(doc)['question']
+    #     return text
+
+    # def doc_to_target(self, doc):
+    #     return " " + self.get_answer_option(self.last_problem(doc))
+
+    # def construct_requests(self, doc, ctx):
+    #     """Uses RequestFactory to construct Requests and returns an iterable of
+    #     Requests which will be sent to the LM.
+
+    #     :param doc:
+    #         The document as returned from training_docs, validation_docs, or test_docs.
+    #     :param ctx: str
+    #         The context string, generated by fewshot_context. This includes the natural
+    #         language description, as well as the few shot examples, and the question
+    #         part of the document for `doc`.
+    #     """
+    #     problem = self.last_problem(doc)
+    #     ll_choices = [
+    #         rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
+    #     ]
+    #     return ll_choices

     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a

@@ -135,11 +144,11 @@ class RACE(Task):
         :param results:
             The results of the requests created in construct_requests.
         """
-        gold = self.letter_to_num[self.last_problem(doc)['answer']]
+        gold = self.letter_to_num[self.doc_to_target(doc)]
+        # gold = self.letter_to_num[self.last_problem(doc)["answer"]]
         pred = np.argmax(results)
         return {"acc": int(pred == gold)}

     def aggregation(self):
         """

@@ -147,9 +156,7 @@ class RACE(Task):
             A dictionary where keys are the names of submetrics and values are
             functions that aggregate a list of metrics
         """
         return {"acc": mean}

     def higher_is_better(self):
         """

@@ -157,6 +164,4 @@ class RACE(Task):
             A dictionary where keys are the names of submetrics and values are
             whether a higher value of the submetric is better
         """
         return {"acc": True}