gaoqiong / lm-evaluation-harness
"magic_pdf/vscode:/vscode.git/clone" did not exist on "c7a685b302ceb161e1e60813457575f42d7a66ab"
Commit a8396b2c, authored Jul 02, 2023 by Benjamin Fattori

add new task type: winograd_schema

Parent: d674c7bd
Showing 3 changed files with 75 additions and 13 deletions.
lm_eval/api/registry.py  +1  -0
lm_eval/api/task.py      +73 -11
lm_eval/evaluator.py     +1  -2
lm_eval/api/registry.py
@@ -80,6 +80,7 @@ DEFAULT_METRIC_REGISTRY = {
     ],
     "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
     "multiple_choice": ["acc", "acc_norm"],
+    "winograd_schema": ["acc", "acc_norm"],
     "greedy_until": ["exact_match"],
 }
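For orientation, this one-line addition gives the new output type a default metric list identical to "multiple_choice". A minimal sketch of the effect, assuming the module layout shown in this diff:

    # Sketch only: tasks declaring output_type "winograd_schema" now pick up
    # the same default metrics as "multiple_choice".
    from lm_eval.api.registry import DEFAULT_METRIC_REGISTRY

    assert DEFAULT_METRIC_REGISTRY["winograd_schema"] == ["acc", "acc_norm"]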
lm_eval/api/task.py
@@ -44,6 +44,7 @@ ALL_OUTPUT_TYPES = [
     "multiple_choice",
     "loglikelihood_rolling",
     "greedy_until",
+    "winograd_schema"
 ]
@@ -75,6 +76,7 @@ class TaskConfig(dict):
     metric_list: str = None
     gold_alias: Union[Callable, str] = None
+    create_choices: Union[Callable, str] = None
     output_type: str = "greedy_until"
     generation_kwargs: dict = None
     filter_list: Union[str, list] = None
@@ -295,6 +297,16 @@ class Task(abc.ABC):
             The processed version of the specified `doc`.
         """
         return doc

+    def create_choices(self, doc):
+        if self._config.create_choices is None:
+            return ast.literal_eval(
+                utils.apply_template(
+                    self._config.template_aliases + "{{answer_choices}}", doc
+                )
+            )
+        else:
+            return self._config.create_choices(doc)
+
     @property
     def instances(self):
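A note on the default branch of create_choices: when no user callable is configured, the method renders the task's "{{answer_choices}}" Jinja alias to a string and converts it back into a Python list with ast.literal_eval. A minimal sketch of that round trip, with an invented rendered value standing in for what utils.apply_template would return:

    import ast

    # Invented example: apply_template renders "{{answer_choices}}" against
    # the doc to a string repr of a list ...
    rendered = "['the trophy', 'the suitcase']"
    # ... which literal_eval converts back into an actual list of choices.
    choices = ast.literal_eval(rendered)
    assert choices == ["the trophy", "the suitcase"]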
@@ -746,11 +758,8 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "multiple_choice":
             # we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
             # TODO: any cleaner way to do this?
-            choices = ast.literal_eval(
-                utils.apply_template(
-                    self._config.template_aliases + "{{answer_choices}}", doc
-                )
-            )
+            choices = self.create_choices(doc)
             request_list = [
                 Instance(
                     request_type="loglikelihood",
@@ -786,6 +795,45 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "greedy_until":
             arguments = (ctx, self._config.generation_kwargs)
+        elif self.OUTPUT_TYPE == "winograd_schema":
+            # similar to multiple_choice task type except each request contains
+            # multiple differing contexts with the same continuation
+            contexts = self.create_choices(doc)
+            choice = self.doc_to_target(doc)
+            request_list = [
+                Instance(
+                    request_type="loglikelihood",
+                    doc=doc,
+                    arguments=(context, " {}".format(choice)),
+                    idx=i,
+                    **kwargs,
+                )
+                for i, context in enumerate(contexts)
+            ]
+            # TODO: we should raise a warning telling users this will at most ~2x runtime.
+            if "acc_mutual_info" in self._metric_fn_list.keys():
+                # if we are calculating multiple choice accuracy
+                # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
+                # here mutual info refers to calculating
+                # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
+                # in other words normalizing by subtracting the unconditional logprob of each choice.
+                request_list.extend(
+                    [
+                        Instance(
+                            request_type="loglikelihood",
+                            doc=doc,
+                            arguments=("", "{}".format(choice)),
+                            idx=i,
+                            **kwargs,
+                        )
+                        for i, choice in enumerate(choices)
+                    ]
+                )
+            return request_list
+
         return Instance(
             request_type=self.OUTPUT_TYPE,
             doc=doc,
             arguments=arguments,
             idx=0,
             **kwargs
         )
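The key difference from multiple_choice is visible in the arguments tuples: a multiple_choice doc scores several continuations against one context, whereas a winograd_schema doc scores one shared continuation against several contexts, one per candidate referent. A sketch with an invented WSC-style doc:

    # Invented example of the request arguments built above: each candidate
    # referent yields a different context; the continuation is shared.
    contexts = [
        "The trophy doesn't fit in the suitcase because the trophy",
        "The trophy doesn't fit in the suitcase because the suitcase",
    ]
    choice = "is too big."
    arguments = [(context, " {}".format(choice)) for context in contexts]
    # -> [("... because the trophy", " is too big."),
    #     ("... because the suitcase", " is too big.")]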
@@ -835,11 +883,7 @@ class ConfigurableTask(Task):
             pred = np.argmax(lls)

             # retrieve choices in List[str] form, to compute choice lengths, etc.
-            choices = ast.literal_eval(
-                utils.apply_template(
-                    self._config.template_aliases + "{{answer_choices}}", doc
-                )
-            )
+            choices = self.create_choices(doc)
             if (
                 2 * len(choices) == len(lls)
                 and "acc_mutual_info" in self._metric_fn_list.keys()
@@ -875,6 +919,24 @@ class ConfigurableTask(Task):
                 acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
                 result_dict["acc_mutual_info"] = acc_mutual_info

+        elif self.OUTPUT_TYPE == "winograd_schema":
+            lls, is_greedy = zip(*results)
+
+            if self._config.gold_alias is not None:
+                gold = int(self.gold_alias(doc))
+            else:
+                gold = int(self.doc_to_target(doc))
+
+            pred = np.argmax(lls)
+            acc = 1.0 if np.argmax(lls) == gold else 0.0
+
+            result_dict = {
+                **({"acc": acc} if "acc" in use_metric else {}),
+                **({"f1": (gold, pred)} if "f1" in use_metric else {}),
+                **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
+                **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
+            }
+
         elif self.OUTPUT_TYPE == "greedy_until":
             if self._config.gold_alias is not None:
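Scoring mirrors multiple_choice: the prediction is the index of the context under which the shared continuation is most likely. A worked sketch with invented numbers:

    import numpy as np

    # results holds one (loglikelihood, is_greedy) pair per candidate context.
    results = [(-4.2, False), (-1.3, False)]   # invented loglikelihoods
    lls, is_greedy = zip(*results)
    gold = 1                                   # index of the correct referent
    pred = np.argmax(lls)                      # argmax over contexts -> 1
    acc = 1.0 if pred == gold else 0.0         # -> 1.0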
@@ -893,7 +955,7 @@ class ConfigurableTask(Task):
         else:
             raise ValueError(
                 f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
-                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', or 'multiple_choice'",
+                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until', 'multiple_choice' or 'winograd_schema'",
             )

         return result_dict
lm_eval/evaluator.py
@@ -214,7 +214,7 @@ def evaluate(
         # aggregate Instances by LM method requested to get output.
         reqtype = (
             "loglikelihood"
-            if task.OUTPUT_TYPE == "multiple_choice"
+            if (task.OUTPUT_TYPE == "multiple_choice" or task.OUTPUT_TYPE == "winograd_schema")
             else task.OUTPUT_TYPE
         )  # TODO: this is hacky, fix in task.py
         requests[reqtype].extend(task.instances)
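Both multiple_choice and winograd_schema instances are served by the model's loglikelihood method, so the router folds them together. A standalone restatement of the condition (function name invented for illustration):

    def reqtype_for(output_type: str) -> str:
        # winograd_schema requests, like multiple_choice ones, are plain
        # loglikelihood queries under the hood.
        if output_type in ("multiple_choice", "winograd_schema"):
            return "loglikelihood"
        return output_type

    assert reqtype_for("winograd_schema") == "loglikelihood"
    assert reqtype_for("greedy_until") == "greedy_until"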
@@ -274,7 +274,6 @@ def evaluate(
                 enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
             )
         )
-
         for doc_id, doc in doc_iterator:
             # subset instances to only this document id ; sort by idx
             requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))