gaoqiong / lm-evaluation-harness

Commit 88745155: "Initial integration"
Authored Apr 25, 2022 by cjlovering
Parent: 6caa0afd

Showing 6 changed files with 480 additions and 318 deletions (+480, -318)
Changed files:
    lm_eval/base.py              +220  -96
    lm_eval/evaluator.py          +70   -35
    lm_eval/tasks/__init__.py     +42   -35
    lm_eval/tasks/coqa.py         +39   -60
    lm_eval/tasks/drop.py         +42   -30
    lm_eval/tasks/race.py         +67   -62
lm_eval/base.py (view file @ 88745155)
import abc
from typing import Iterable
import numpy as np
import random
import re
...
...
@@ -118,7 +119,6 @@ class LM(abc.ABC):
class BaseLM(LM):
    @property
    @abstractmethod
    def eot_token_id(self):
...
...
@@ -145,13 +145,16 @@ class BaseLM(LM):
        pass

    @abstractmethod
    def tok_encode(self, string: str):
        pass

    @abstractmethod
    def tok_decode(self, tokens: Iterable[int]):
        pass

    @abstractmethod
    def _model_generate(self, context, max_length, eos_token_id):
        pass

    @abstractmethod
    def _model_call(self, inps):
...
...
@@ -187,19 +190,26 @@ class BaseLM(LM):
        # TODO: automatic batch size detection for vectorization
        loglikelihoods = []
-       for string, in tqdm(requests):
+       for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
            # that
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
            )

            # discard is_greedy
            string_nll = [x[0] for x in string_nll]
...
...
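The loop above scores one long string by cutting its tokens into rolling windows, each predicted with the end-of-text token (or a short tail of preceding tokens) as context, and then sums the per-window log-likelihoods. The helper below is a minimal, self-contained sketch of that windowing idea; it is not the harness's own `utils.get_rolling_token_windows`, whose exact output format is only partly visible in this diff.

    def rolling_windows(tokens, prefix_token, max_seq_len):
        """Yield (context, continuation) pairs covering `tokens` exactly once.

        Each window holds at most max_seq_len tokens in total (context plus
        continuation), mirroring the context_len=1 call in the hunk above."""
        windows = []
        pred_start = 0
        while pred_start < len(tokens):
            if pred_start == 0:
                context = [prefix_token]
            else:
                context = tokens[pred_start - 1 : pred_start]  # one token of left context
            pred_end = min(pred_start + (max_seq_len - len(context)), len(tokens))
            windows.append((context, tokens[pred_start:pred_end]))
            pred_start = pred_end
        return windows

    # Ten token ids, windows of at most four tokens each.
    print(rolling_windows(list(range(10)), prefix_token=50256, max_seq_len=4))

Each window is then prepended with `(None,)` so it matches the `(cache_key, context, continuation)` triples that `_loglikelihood_tokens` expects.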
@@ -226,7 +236,9 @@ class BaseLM(LM):
        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size
        ):
            inps = []
            cont_toks_list = []
            inplens = []
...
...
@@ -252,44 +264,60 @@ class BaseLM(LM):
                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                ).to(self.device)
                (inplen,) = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = (
                    padding_length if padding_length is not None else inplen
                )

                # pad length from seq to padding_length
                inp = torch.cat(
                    [
                        inp,  # [seq]
                        torch.zeros(padding_length - inplen, dtype=torch.long).to(
                            inp.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                )

                inps.append(inp.unsqueeze(0))  # [1, padding_length]
                cont_toks_list.append(cont)
                inplens.append(inplen)

            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length
            multi_logits = F.log_softmax(
                self._model_call(batched_inps), dim=-1
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
                chunk, multi_logits, inps, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[inplen - contlen : inplen].unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)  # [1, seq]
                max_equal = (greedy_tokens == cont_toks).all()

                # Obtain log-probs at the corresponding continuation token indices
                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                # Answer: (log prob, is-exact-match)
                answer = (float(logits.sum()), bool(max_equal))
...
...
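The scoring path above pads a batch of context-plus-continuation sequences, runs one forward pass, and then reads the continuation's log-probabilities out of the `[batch, padding_length, vocab]` tensor with `torch.gather`, summing them into a single log-likelihood and flagging whether greedy decoding would have reproduced the continuation exactly. A minimal self-contained illustration of that gather step, with a random tensor standing in for the output of `self._model_call`:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    vocab, seq = 13, 5
    logits = F.log_softmax(torch.randn(1, seq, vocab), dim=-1)      # [1, seq, vocab]
    cont_toks = torch.tensor([[3, 7, 1, 0, 9]], dtype=torch.long)   # [1, seq]

    greedy_tokens = logits.argmax(dim=-1)           # tokens greedy decoding would pick
    max_equal = (greedy_tokens == cont_toks).all()  # exact-match flag
    token_logprobs = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
    answer = (float(token_logprobs.sum()), bool(max_equal))
    print(answer)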
@@ -319,13 +347,17 @@ class BaseLM(LM):
            if isinstance(until, str):
                until = [until]

            (primary_until,) = self.tok_encode(until[0])

            context_enc = torch.tensor(
                [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
            ).to(self.device)

            cont = self._model_generate(
                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
            )

            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])

            for term in until:
                s = s.split(term)[0]
...
...
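Only the first stop sequence (`primary_until`) is passed to the model as an end-of-generation token; the remaining stop strings are applied afterwards by repeatedly splitting the decoded text. A tiny illustration of that trimming, with made-up strings:

    until = ["\nQ:", "\n\n"]
    s = " Paris is the capital of France.\nQ: What about Spain?"
    for term in until:
        s = s.split(term)[0]
    print(repr(s))  # ' Paris is the capital of France.'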
@@ -383,7 +415,7 @@ class Task(abc.ABC):
        self._fewshot_docs = None

    def download(self, data_dir=None, cache_dir=None, download_mode=None):
        """Downloads and returns the task dataset.
        Override this method to download the dataset from a custom API.

        :param data_dir: str
...
...
@@ -412,7 +444,7 @@ class Task(abc.ABC):
            name=self.DATASET_NAME,
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
        )

    @abstractmethod
...
...
@@ -478,7 +510,7 @@ class Task(abc.ABC):
    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
...
...
@@ -523,15 +555,19 @@ class Task(abc.ABC):
    def fewshot_description(self):
        import warnings

        warnings.warn(
            "`fewshot_description` will be removed in futures versions. Pass "
            "any custom descriptions to the `evaluate` function instead.",
            DeprecationWarning,
        )
        return ""

    @utils.positional_deprecated
    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
...
...
@@ -548,7 +584,9 @@ class Task(abc.ABC):
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
...
...
@@ -556,7 +594,9 @@ class Task(abc.ABC):
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        description = description + "\n\n" if description else ""
...
...
@@ -569,7 +609,9 @@ class Task(abc.ABC):
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
...
...
@@ -577,23 +619,90 @@ class Task(abc.ABC):
            # get rid of the doc that's the one we're evaluating, if it's in the fewshot
            fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                "\n\n".join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + "\n\n"
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
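For reference, the string assembled here is simply `description + labeled_examples + example`: every few-shot document is rendered as `doc_to_text(doc) + doc_to_target(doc)`, the rendered examples are joined with blank lines, and the unanswered query is appended last. A hand-built illustration with invented documents (the field names are hypothetical, not taken from any registered task):

    description = "Answer the questions.\n\n"
    fewshot_docs = [
        {"q": "Q: 2+2?\nA:", "a": " 4"},
        {"q": "Q: Capital of France?\nA:", "a": " Paris"},
    ]
    labeled_examples = "\n\n".join(d["q"] + d["a"] for d in fewshot_docs) + "\n\n"
    example = "Q: 3+5?\nA:"
    print(description + labeled_examples + example)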
class PromptSourceTask(Task):
    def __init__(
        self, data_dir=None, cache_dir=None, download_mode=None, prompt=None
    ):
        super().__init__(data_dir, cache_dir, download_mode)
        self.prompt = prompt

    def doc_to_target(self, doc):
        _, target = prompt.apply(doc)
        return f"{target}"

    def doc_to_text(self, doc):
        text, _ = prompt.apply(doc)
        return text

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        _requests = []
        if self.prompt.metadata.choices_in_prompt:
            for answer_choice in prompt.get_fixed_answer_choices_list():
                ll_answer_choice, _ = rf.loglikelihood(ctx, f"{answer_choice}")
                _requests.append(ll_answer_choice)
        else:
            # TODO(Albert): What is the stop symbol? Is it model specific?
            ll_greedy, _ = rf.greedy_until(ctx, ["\nQ:"])
            _requests.append(ll_greedy)
        return _requests

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        raise NotImplementedError(
            "Implement process results using the `prompt.metadata.metrics`. See below."
        )
        if self.prompt.metadata.choices_in_prompt:
            for result, answer_choice in zip(
                prompt.get_fixed_answer_choices_list(), results
            ):
                pass
        else:
            continuation = results

        # Map metric name to HF metric.
        # TODO(Albert): What is Other?
        metric_names = prompt.metadata.metrics
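The new `PromptSourceTask` delegates document formatting to a promptsource template: `prompt.apply(doc)` returns the rendered input string and target string for one example, and the template's metadata drives how requests are built. A rough usage sketch of that template API, assuming promptsource is installed; the dataset and document shown are illustrative choices, not taken from this commit:

    from promptsource.templates import DatasetTemplates

    templates = DatasetTemplates("super_glue", "boolq")
    prompt = templates[templates.all_template_names[0]]

    doc = {
        "passage": "The harness now supports promptsource templates.",
        "question": "Does the harness support promptsource templates?",
        "label": 1,
    }
    text, target = prompt.apply(doc)            # rendered input and target strings
    print(text)
    print(target)
    print(prompt.metadata.choices_in_prompt, prompt.metadata.metrics)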
class MultipleChoiceTask(Task):
    def doc_to_target(self, doc):
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        lls = [
            rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
        ]

        return lls
...
...
@@ -601,9 +710,9 @@ class MultipleChoiceTask(Task):
    def process_results(self, doc, results):
        gold = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        return {
            "acc": acc,
...
...
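`process_results` receives one log-likelihood per answer choice. `acc` checks whether the raw argmax hits the gold index, while `acc_norm` first divides each log-likelihood by the character length of its choice, which removes the advantage short continuations get from accumulating fewer negative token scores. A small worked example with invented numbers:

    import numpy as np

    results = np.array([-9.2, -11.0, -14.5])          # log-likelihoods of three choices
    choices = ["Paris", "Lyon", "The city of Marseille"]
    gold = 2

    completion_len = np.array([float(len(c)) for c in choices])
    acc = 1.0 if np.argmax(results) == gold else 0.0                      # 0.0: raw argmax picks "Paris"
    acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0  # 1.0: length-normalized argmax picks the gold choice
    print(acc, acc_norm)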
@@ -624,7 +733,6 @@ class MultipleChoiceTask(Task):
class PerplexityTask(Task, abc.ABC):
    def has_training_docs(self):
        return False
...
...
@@ -632,9 +740,15 @@ class PerplexityTask(Task, abc.ABC):
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
...
...
@@ -642,7 +756,9 @@ class PerplexityTask(Task, abc.ABC):
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""
...
...
@@ -665,7 +781,7 @@ class PerplexityTask(Task, abc.ABC):
        return req

    def process_results(self, doc, results):
        (loglikelihood,) = results
        words = self.count_words(doc)
        bytes_ = self.count_bytes(doc)
        return {
...
...
@@ -687,13 +803,13 @@ class PerplexityTask(Task, abc.ABC):
    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))


def hash_args(attr, args):
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()
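`hash_args` is what the caching layer uses to key stored responses: the request method name plus its arguments are serialized to JSON and hashed, so identical requests map to the same SQLite key. A standalone copy of the function, for illustration:

    import hashlib
    import json

    def hash_args(attr, args):
        dat = json.dumps([attr] + list(args))
        return hashlib.sha256(dat.encode("utf-8")).hexdigest()

    # Deterministic 64-character hex digest; the same request always yields the same key.
    print(hash_args("loglikelihood", ("Q: 2+2?\nA:", " 4")))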
class CacheHook:
...
...
@@ -764,6 +880,7 @@ class CachingLM:
            self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
...
...
@@ -771,16 +888,18 @@
REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}


class Request:
    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS.keys():
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
...
...
@@ -788,17 +907,21 @@ class Request:
    def __iter__(self):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        for i in range(REQUEST_RETURN_LENGTHS[self.request_type]):
            yield Request(self.request_type, self.args, i)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"
...
...
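Because `REQUEST_RETURN_LENGTHS["loglikelihood"]` is 2, iterating over a loglikelihood request yields two indexed sub-requests; this is what lets task code write `ll, _ = rf.loglikelihood(ctx, continuation)` and have the evaluator later route each indexed part to the matching slot of the model's `(logprob, is_greedy)` answer. A small usage sketch, assuming the package is importable:

    from lm_eval.base import RequestFactory

    rf = RequestFactory()
    ll_part, greedy_part = rf.loglikelihood("Q: 2+2?\nA:", " 4")
    print(ll_part)      # indexed sub-request 0: the log-probability slot
    print(greedy_part)  # indexed sub-request 1: the is-greedy slot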
@@ -808,6 +931,7 @@ class RequestFactory:
    def __getattr__(self, attr):
        def fn(*args):
            return Request(attr, args)

        return fn
...
...
lm_eval/evaluator.py (view file @ 88745155)
...
...
@@ -6,15 +6,27 @@ import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import promptsource
import numpy as np

from promptsource.templates import DatasetTemplates
from lm_eval.utils import positional_deprecated, run_task_tests


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
...
...
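A typical call, for orientation. This is a sketch: the model and task names are examples, and with this commit the task list is resolved through promptsource templates via `get_task_dict_promptsource`, so each requested task expands into one scoring run per available prompt.

    import lm_eval.evaluator

    results = lm_eval.evaluator.simple_evaluate(
        model="gpt2",             # any model name registered in lm_eval.models
        model_args="",            # e.g. "pretrained=gpt2-medium"
        tasks=["coqa", "drop"],   # datasets that have promptsource templates
        num_fewshot=0,
        batch_size=1,
        device="cpu",
        no_cache=True,
        limit=8,                  # score only a few documents while testing
    )
    print(lm_eval.evaluator.make_table(results))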
@@ -49,20 +61,26 @@ def simple_evaluate(model, model_args=None, tasks=[],
    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/"
            + model
            + "_"
            + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
            + ".db",
        )

-   task_dict = lm_eval.tasks.get_task_dict(tasks)
+   task_dict = lm_eval.tasks.get_task_dict_promptsource(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)
...
...
@@ -72,7 +90,7 @@ def simple_evaluate(model, model_args=None, tasks=[],
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        description_dict=description_dict,
    )

    # add info about the model and few shot config
...
...
@@ -85,14 +103,22 @@ def simple_evaluate(model, model_args=None, tasks=[],
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    provide_description=None,
    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
...
...
@@ -118,12 +144,14 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    assert not provide_description  # not implemented.
    if provide_description is not None:
        # nudge people to not specify it at all
        print(
            "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
        )

    task_dict_items = [
        (name, task)
        for name, task in task_dict.items()
        if (task.has_validation_docs() or task.has_test_docs())
    ]

    results = collections.defaultdict(dict)
...
...
@@ -158,15 +186,16 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        rnd.seed(42)
        rnd.shuffle(task_docs)

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )

        for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
            docs[(task_name, doc_id)] = doc
            ctx = task.fewshot_context(
                doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
            )
            reqs = task.construct_requests(doc, ctx)
            if not isinstance(reqs, (list, tuple)):
...
...
@@ -189,7 +218,9 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        print("Running", reqtype, "requests")
        resps = getattr(lm, reqtype)([req.args for req in reqs])
        resps = [
            x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
        ]

        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
            process_res_queue[(task_name, doc_id)].append((i, resp))
...
...
@@ -208,24 +239,28 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
        for metric, value in metrics.items():
            vals[(task_name, metric)].append(value)

        task_name, prompt_name = task_name.split("+")
        results[task_name]["task_name"] = task_name
        results[task_name]["prompt_name"] = prompt_name

    # aggregate results
    for (task_name, metric), items in vals.items():
        task = task_dict[task_name]
        results[task_name][metric] = task.aggregation()[metric](items)

        # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
        # so we run them less iterations. still looking for a cleaner way to do this
        stderr = lm_eval.metrics.stderr_for_metric(
            metric=task.aggregation()[metric],
            bootstrap_iters=min(bootstrap_iters, 1000)
            if metric in ["bleu", "chrf", "ter"]
            else bootstrap_iters,
        )
        if stderr is not None:
            results[task_name][metric + "_stderr"] = stderr(items)

    return {"results": dict(results), "versions": dict(versions)}
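The keys produced by `get_task_dict_promptsource` have the form "<task>+<prompt name>", which is why the loop above splits `task_name` on "+" to recover the dataset name and prompt name for reporting. For example (the prompt name here is hypothetical, for illustration only):

    key = "coqa+generate_answer"
    task_name, prompt_name = key.split("+")
    print(task_name, prompt_name)  # coqa generate_answer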
def make_table(result_dict):
...
...
@@ -247,9 +282,9 @@ def make_table(result_dict):
        if m + "_stderr" in dic:
            se = dic[m + "_stderr"]
            values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
        else:
            values.append([k, version, m, "%.4f" % v, "", ""])
        k = ""
        version = ""
    md_writer.value_matrix = values
...
...
lm_eval/tasks/__init__.py (view file @ 88745155)

from promptsource.templates import DatasetTemplates
from pprint import pprint
from typing import List, Union
...
...
@@ -58,8 +60,8 @@ from . import storycloze
# 6 total
gpt3_translation_benchmarks = {
    "wmt14": ["en-fr", "fr-en"],  # French
    "wmt16": ["en-ro", "ro-en", "de-en", "en-de"],  # German, Romanian
}
...
...
@@ -67,7 +69,7 @@ gpt3_translation_benchmarks = {
selected_translation_benchmarks = {
    **gpt3_translation_benchmarks,
    "wmt20": sacrebleu.get_langpairs_for_testset("wmt20"),
    "iwslt17": ["en-ar", "ar-en"],  # Arabic
}

# 319 total
...
...
@@ -91,7 +93,7 @@ TASK_REGISTRY = {
    "rte": glue.RTE,
    "qnli": glue.QNLI,
    "qqp": glue.QQP,
    # "stsb": glue.STSB, # not implemented yet
    "sst": glue.SST,
    "wnli": glue.WNLI,
    # SuperGLUE
...
...
@@ -102,34 +104,26 @@ TASK_REGISTRY = {
    "record": superglue.ReCoRD,
    "wic": superglue.WordsInContext,
    "wsc": superglue.SGWinogradSchemaChallenge,
    # Order by benchmark/genre?
    "coqa": coqa.CoQA,
    "drop": drop.DROP,
    "lambada": lambada.LAMBADA,
    "lambada_cloze": lambada_cloze.LAMBADA_cloze,
    # multilingual lambada
    **lambada_multilingual.construct_tasks(),
    "wikitext": wikitext.WikiText,
    # "cbt-cn": cbt.CBTCN, # disabled pending context length fix
    # "cbt-ne": cbt.CBTNE, # disabled pending context length fix
    "piqa": piqa.PiQA,
    "prost": prost.PROST,
    "mc_taco": mc_taco.MCTACO,
    # Science related
    "pubmedqa": pubmedqa.Pubmed_QA,
    "sciq": sciq.SciQ,
    "qasper": qasper.QASPER,
    "qa4mre_2011": qa4mre.QA4MRE_2011,
    "qa4mre_2012": qa4mre.QA4MRE_2012,
    "qa4mre_2013": qa4mre.QA4MRE_2013,
    "triviaqa": triviaqa.TriviaQA,
    "arc_easy": arc.ARCEasy,
    "arc_challenge": arc.ARCChallenge,
...
...
@@ -150,21 +144,17 @@ TASK_REGISTRY = {
    "anli_r1": anli.ANLIRound1,
    "anli_r2": anli.ANLIRound2,
    "anli_r3": anli.ANLIRound3,
    "ethics_cm": hendrycks_ethics.EthicsCM,
    "ethics_deontology": hendrycks_ethics.EthicsDeontology,
    "ethics_justice": hendrycks_ethics.EthicsJustice,
    "ethics_utilitarianism_original": hendrycks_ethics.EthicsUtilitarianismOriginal,
    "ethics_utilitarianism": hendrycks_ethics.EthicsUtilitarianism,
    "ethics_virtue": hendrycks_ethics.EthicsVirtue,
    "truthfulqa_mc": truthfulqa.TruthfulQAMultipleChoice,
    "truthfulqa_gen": truthfulqa.TruthfulQAGeneration,
    # dialogue
    "mutual": mutual.MuTual,
    "mutual_plus": mutual.MuTualPlus,
    # math
    "math_algebra": hendrycks_math.MathAlgebra,
    "math_counting_and_prob": hendrycks_math.MathCountingAndProbability,
...
...
@@ -175,7 +165,6 @@ TASK_REGISTRY = {
    "math_precalc": hendrycks_math.MathPrecalculus,
    "math_asdiv": asdiv.Asdiv,
    "gsm8k": gsm8k.GradeSchoolMath8K,
    # arithmetic
    "arithmetic_2da": arithmetic.Arithmetic2DPlus,
    "arithmetic_2ds": arithmetic.Arithmetic2DMinus,
...
...
@@ -189,22 +178,18 @@ TASK_REGISTRY = {
    "arithmetic_1dc": arithmetic.Arithmetic1DComposite,
    # TODO Perhaps make these groups of tasks
    #   e.g. anli, arithmetic, openai_translations, harness_translations
    # hendrycksTest (57 tasks)
    **hendrycks_test.create_all_tasks(),
    # e.g. wmt14-fr-en
    **translation.create_tasks_from_benchmarks(gpt3_translation_benchmarks),
    # chef's selection, mostly wmt20
    **translation.create_tasks_from_benchmarks(selected_translation_benchmarks),
    # Word Scrambling and Manipulation Tasks
    "anagrams1": unscramble.Anagrams1,
    "anagrams2": unscramble.Anagrams2,
    "cycle_letters": unscramble.CycleLetters,
    "random_insertion": unscramble.RandomInsertion,
    "reversed_words": unscramble.ReversedWords,
    # Pile
    "pile_arxiv": pile.PileArxiv,
    "pile_books3": pile.PileBooks3,
...
...
@@ -228,7 +213,6 @@ TASK_REGISTRY = {
    "pile_ubuntu-irc": pile.PileUbuntuIrc,
    "pile_wikipedia": pile.PileWikipedia,
    "pile_youtubesubtitles": pile.PileYoutubeSubtitles,
    # BLiMP
    "blimp_adjunct_island": blimp.BlimpAdjunctIsland,
    "blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
...
...
@@ -297,7 +281,6 @@ TASK_REGISTRY = {
    "blimp_wh_vs_that_no_gap_long_distance": blimp.BlimpWhVsThatNoGapLongDistance,
    "blimp_wh_vs_that_with_gap": blimp.BlimpWhVsThatWithGap,
    "blimp_wh_vs_that_with_gap_long_distance": blimp.BlimpWhVsThatWithGapLongDistance,
    # Requires manual download of data.
    # "storycloze_2016": storycloze.StoryCloze2016,
    # "storycloze_2018": storycloze.StoryCloze2018,
...
...
@@ -323,17 +306,41 @@ def get_task_name_from_object(task_object):
        return name

    # this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
    return (
        task_object.EVAL_HARNESS_NAME
        if hasattr(task_object, "EVAL_HARNESS_NAME")
        else type(task_object).__name__
    )


def get_task_dict(task_name_list: List[Union[str, lm_eval.base.Task]]):
    task_name_dict = {
        task_name: get_task(task_name)()
        for task_name in task_name_list
        if isinstance(task_name, str)
    }
    task_name_from_object_dict = {
        get_task_name_from_object(task_object): task_object
        for task_object in task_name_list
        if not isinstance(task_object, str)
    }
    assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
    return {**task_name_dict, **task_name_from_object_dict}


def get_task_dict_promptsource(task_name_list: List[str]):
    """Loads a task instance for each prompt written for that task."""
    task_name_dict = {}
    for task_name in task_name_list:
        assert isinstance(task_name, str)
        task_prompts = DatasetTemplates(task_name)
        for prompt_name in task_prompts.all_template_names:
            prompt = task_prompts[prompt_name]
            # NOTE: We choose a sep that can be easily split.
            task_name_dict[f"{task_name}+{prompt_name}"] = get_task(task_name)(
                prompt=prompt
            )
    return task_name_dict
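The net effect is one task instance per (dataset, prompt) pair, keyed as "<task>+<prompt name>". A rough sketch of how those keys are enumerated, assuming promptsource ships templates for the requested dataset (the dataset name is illustrative, not a registered harness task):

    from promptsource.templates import DatasetTemplates

    task_name = "ag_news"
    prompts = DatasetTemplates(task_name)
    for prompt_name in prompts.all_template_names:
        print(f"{task_name}+{prompt_name}")  # keys of the dict returned above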
lm_eval/tasks/coqa.py (view file @ 88745155)
...
...
@@ -51,44 +51,22 @@ class CoQA(Task):
    def test_docs(self):
        pass

-   def doc_to_text(self, doc):
-       # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
-       # and a question qi, the task is to predict the answer ai
-       doc_text = doc["story"] + '\n\n'
-       for (q, a) in zip_longest(
-           doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
-       ):  # omit target answer ai
-           question = f"Q: {q}\n\n"
-           answer = f"A: {a}\n\n" if a is not None else "A:"
-           doc_text += question + answer
-       return doc_text
-
-   @classmethod
-   def get_answers(cls, doc, turn_id):
-       # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
-       answers = []
-       answer_forturn = doc["answers"]["input_text"][turn_id - 1]
-       answers.append(answer_forturn)
-
-       additional_answers = doc.get("additional_answers")
-       if additional_answers:
-           for key in additional_answers:
-               additional_answer_for_turn = additional_answers[key]["input_text"][turn_id - 1]
-               if additional_answer_for_turn.lower() not in map(str.lower, answers):
-                   answers.append(additional_answer_for_turn)
-       return answers
-
-   @classmethod
-   def get_answer_choice(self, raw_text):
-       # Function maps answers to CoQA answer categories
-       # ~ 1/5 of the CoQA answers are Yes/No
-       # ~ 2/3 of the CoQA answers are span-based
-       # (answers overlap with the passage ignoring punctuation and case mismatch)
-       if raw_text == "unknown":
-           return '0'
-       if squad_metrics.normalize_answer(raw_text) == "yes":
-           return '1'
-       if squad_metrics.normalize_answer(raw_text) == "no":
-           return '2'
-       return '3'  # Not a yes/no question
+   # @classmethod
+   # def get_answers(cls, doc, turn_id):
+   #     # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
+   #     answers = []
+   #     answer_forturn = doc["answers"]["input_text"][turn_id - 1]
+   #     answers.append(answer_forturn)
+   #     additional_answers = doc.get("additional_answers")
+   #     if additional_answers:
+   #         for key in additional_answers:
+   #             additional_answer_for_turn = additional_answers[key]["input_text"][
+   #                 turn_id - 1
+   #             ]
+   #             if additional_answer_for_turn.lower() not in map(str.lower, answers):
+   #                 answers.append(additional_answer_for_turn)
+   #     return answers

    @staticmethod
    def compute_scores(gold_list, pred):
...
...
@@ -98,25 +76,23 @@ class CoQA(Task):
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                gold_answers = gold_list[0:i] + gold_list[i + 1 :]
                # predictions compared against (n) golds and take maximum
                em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
        else:
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

-       return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
-
-   def doc_to_target(self, doc, turnid=None):
-       # Default to prediction of last turn.
-       if turnid is None:
-           turnid = len(doc["questions"]["input_text"])
-       raw_text = doc['answers']["input_text"][turnid - 1]
-       return " " + raw_text
+       return {
+           "em": em_sum / max(1, len(gold_list)),
+           "f1": f1_sum / max(1, len(gold_list)),
+       }

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
...
...
@@ -126,7 +102,7 @@ class CoQA(Task):
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        cont_request = rf.greedy_until(ctx, ["\nQ:"])
        return cont_request

    def process_results(self, doc, results):
...
...
@@ -139,15 +115,18 @@ class CoQA(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
-       turn_id = len(doc["questions"]["input_text"])
-       gold_list = self.get_answers(doc, turn_id)
-       pred = results[0].strip().split('\n')[0]
+       target = self.doc_to_target(doc).strip()
+       pred = results[0].strip().split("\n")[0]
+       # turn_id = len(doc["questions"]["input_text"])
+       # gold_list = self.get_answers(doc, turn_id)

-       scores = self.compute_scores(gold_list, pred)
+       # TODO: Add HF metrics mapped from promptsource metadata.
+       scores = self.compute_scores([target], pred)

        return {
            "f1": scores["f1"],
            "em": scores["em"],
        }

    def higher_is_better(self):
...
...
lm_eval/tasks/drop.py (view file @ 88745155)
...
...
@@ -70,21 +70,26 @@ class DROP(Task):
    @classmethod
    def get_answers(cls, qa):
        def _flatten_validated_answers(validated_answers):
            """Flattens a dict of lists of validated answers.
            {"number": ['1', '8'], ...}
            -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
            """
            vas = []
            for i in range(len(validated_answers["number"])):
                vas.append(
                    {
                        "number": validated_answers["number"][i],
                        "date": validated_answers["date"][i],
                        "spans": validated_answers["spans"][i],
                    }
                )
            return vas

        answers = []
        answers_set = set()
        candidates = [qa["answer"]] + _flatten_validated_answers(
            qa["validated_answers"]
        )
        for candidate in candidates:
            answer = cls.parse_answer(candidate)
            if answer in answers_set:
...
...
@@ -100,15 +105,17 @@ class DROP(Task):
            return (str(answer["number"]),)
        if answer["spans"] != []:
            return tuple(answer["spans"])
        return (
            " ".join(
                [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
            ).strip(),
        )

-   def doc_to_text(self, doc):
-       return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
+   # def doc_to_text(self, doc):
+   #     return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"

-   def doc_to_target(self, doc):
-       return " " + ", ".join(doc["answers"][0])
+   # def doc_to_target(self, doc):
+   #     return " " + ", ".join(doc["answers"][0])

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
...
...
@@ -134,7 +141,13 @@ class DROP(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
-       preds, golds = results, doc["answers"]
+       pred = results[0].strip()
+       target = self.doc_to_target(doc).strip()
+       preds = [pred]
+       golds = [target]

        max_em = 0
        max_f1 = 0
        for gold_answer in golds:
...
...
@@ -142,10 +155,7 @@ class DROP(Task):
            if gold_answer[0].strip():
                max_em = max(max_em, exact_match)
                max_f1 = max(max_f1, f1_score)

        return {"em": max_em, "f1": max_f1}

    def get_metrics(self, predicted, gold):
        """
...
...
@@ -158,7 +168,9 @@ class DROP(Task):
        predicted_bags = self._answer_to_bags(predicted)
        gold_bags = self._answer_to_bags(gold)

        if set(predicted_bags[0]) == set(gold_bags[0]) and len(
            predicted_bags[0]
        ) == len(gold_bags[0]):
            exact_match = 1.0
        else:
            exact_match = 0.0
...
...
@@ -190,7 +202,9 @@ class DROP(Task):
        for gold_index, gold_item in enumerate(gold):
            for pred_index, pred_item in enumerate(predicted):
                if self._match_numbers_if_present(gold_item, pred_item):
                    scores[gold_index, pred_index] = self._compute_f1(
                        pred_item, gold_item
                    )
        row_ind, col_ind = linear_sum_assignment(-scores)

        max_scores = np.zeros([max(len(gold), len(predicted))])
...
...
@@ -256,7 +270,11 @@ class DROP(Task):
    def _normalize(self, answer):
        tokens = [
            self._white_space_fix(
                self._remove_articles(self._fix_number(self._remove_punc(token.lower())))
            )
            for token in self._tokenize(answer)
        ]
        tokens = [token for token in tokens if token.strip()]
...
...
@@ -269,10 +287,7 @@ class DROP(Task):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"em": mean, "f1": mean}

    def higher_is_better(self):
        """
...
...
@@ -280,7 +295,4 @@ class DROP(Task):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"em": True, "f1": True}
lm_eval/tasks/race.py (view file @ 88745155)
...
...
@@ -40,7 +40,7 @@ class RACE(Task):
    DATASET_NAME = "high"

    cache = {}
    letter_to_num = {"A": 0, "B": 1, "C": 2, "D": 3}

    def has_training_docs(self):
        return True
...
...
@@ -59,17 +59,27 @@ class RACE(Task):
        # is shown that one document is made per passage.
        r = collections.defaultdict(list)
        for item in datasets.load_dataset(
            path=self.DATASET_PATH, name=self.DATASET_NAME
        )[set]:
            r[item["article"]].append(item)

        res = list(
            r.values()
            >> each(
                lambda x: {
                    "article": x[0]["article"],
                    "problems": x
                    >> each(
                        lambda y: {
                            "question": y["question"],
                            "answer": y["answer"],
                            "options": y["options"],
                        }
                    ),
                }
            )
        )
        self.cache[set] = res
        return res
...
...
@@ -85,45 +95,44 @@ class RACE(Task):
    @classmethod
    def get_answer_option(cls, problem):
        answer = cls.letter_to_num[problem["answer"]]
        return problem["options"][answer]

    @classmethod
    def last_problem(cls, doc):
        return doc["problems"][-1]

-   def doc_to_text(self, doc):
-       text = 'Article: ' + doc['article'] + '\n\n'
-       for problem in doc['problems'][:-1]:
-           if problem['question'][-6:] == ' _ .':
-               text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
-           else:
-               question = 'Question: ' + problem['question'] + '\n'
-               answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
-               text += question + answer
-       text += self.last_problem(doc)['question']
-       return text
-
-   def doc_to_target(self, doc):
-       return " " + self.get_answer_option(self.last_problem(doc))
-
-   def construct_requests(self, doc, ctx):
-       """ Uses RequestFactory to construct Requests and returns an iterable of
-       Requests which will be sent to the LM.
-
-       :param doc:
-           The document as returned from training_docs, validation_docs, or test_docs.
-       :param ctx: str
-           The context string, generated by fewshot_context. This includes the natural
-           language description, as well as the few shot examples, and the question
-           part of the document for `doc`.
-       """
-       problem = self.last_problem(doc)
-       ll_choices = [
-           rf.loglikelihood(ctx, " " + problem['options'][i])[0] for i in range(4)
-       ]
-       return ll_choices
+   # def doc_to_text(self, doc):
+   #     text = 'Article: ' + doc['article'] + '\n\n'
+   #     for problem in doc['problems'][:-1]:
+   #         if problem['question'][-6:] == ' _ .':
+   #             text += problem['question'][-5:] + self.get_answer_option(problem) + '\n'
+   #         else:
+   #             question = 'Question: ' + problem['question'] + '\n'
+   #             answer = 'Answer: ' + self.get_answer_option(problem) + '\n'
+   #             text += question + answer
+   #     text += self.last_problem(doc)['question']
+   #     return text
+
+   # def doc_to_target(self, doc):
+   #     return " " + self.get_answer_option(self.last_problem(doc))
+
+   # def construct_requests(self, doc, ctx):
+   #     """Uses RequestFactory to construct Requests and returns an iterable of
+   #     Requests which will be sent to the LM.
+   #     :param doc:
+   #         The document as returned from training_docs, validation_docs, or test_docs.
+   #     :param ctx: str
+   #         The context string, generated by fewshot_context. This includes the natural
+   #         language description, as well as the few shot examples, and the question
+   #         part of the document for `doc`.
+   #     """
+   #     problem = self.last_problem(doc)
+   #     ll_choices = [
+   #         rf.loglikelihood(ctx, " " + problem["options"][i])[0] for i in range(4)
+   #     ]
+   #     return ll_choices

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
...
...
@@ -135,11 +144,11 @@ class RACE(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
-       gold = self.letter_to_num[self.last_problem(doc)['answer']]
+       gold = self.letter_to_num[self.doc_to_target(doc)]
+       # gold = self.letter_to_num[self.last_problem(doc)["answer"]]
        pred = np.argmax(results)
        return {"acc": int(pred == gold)}

    def aggregation(self):
        """
...
...
@@ -147,9 +156,7 @@ class RACE(Task):
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"acc": mean}

    def higher_is_better(self):
        """
...
...
@@ -157,6 +164,4 @@ class RACE(Task):
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"acc": True}