gaoqiong / lm-evaluation-harness · Commits

Commit 487f7811, authored May 02, 2023 by haileyschoelkopf
parent e7f49cca

    prelim. multiple choice support

Showing 9 changed files with 140 additions and 63 deletions (+140, -63)
lm_eval/api/__init__.py     +3    -0
lm_eval/api/instance.py     +1    -0
lm_eval/api/metrics.py      +6    -1
lm_eval/api/model.py        +14   -10
lm_eval/api/task.py         +102  -44
lm_eval/evaluator.py        +2    -1
lm_eval/models/gpt2.py      +1    -1
lm_eval/tasks/arc.yaml      +8    -3
lm_eval/tasks/lambada.yaml  +3    -3
lm_eval/api/__init__.py (+3, -0)

@@ -22,4 +22,7 @@ HIGHER_IS_BETTER_REGISTRY = {
     "bleu": True,
     "chrf": True,
     "ter": False,
+    "acc": True,
+    "acc_norm": True,
 }
\ No newline at end of file
lm_eval/api/instance.py (+1, -0)

@@ -11,6 +11,7 @@ class Instance:
     resps: list = field(default_factory=list)
     filtered_resps: dict = field(default_factory=dict)

     # initialized after init
     task_name: str = None
     doc_id: str = None
+    repeats: str = None
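For orientation, a minimal sketch of how a task might populate an Instance, including the new repeats field. The concrete values are invented; only the field and keyword names come from the hunks in this commit.

# Hypothetical usage of the dataclass shown above.
from lm_eval.api.instance import Instance

inst = Instance(
    request_type="loglikelihood",                       # which LM method answers this request
    doc={"question": "2 + 2 = ?", "choices": ["3", "4"]},
    arguments=("Question: 2 + 2 = ?\nAnswer:", " 4"),
    id_=0,
)
# these fields are "initialized after init" by the harness
inst.task_name = "arc_challenge"
inst.doc_id = 0
inst.repeats = 1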
lm_eval/api/metrics.py (+6, -1)

@@ -10,7 +10,10 @@ import evaluate
 AGGREGATION_REGISTRY = {}
-METRIC_REGISTRY = {}
+METRIC_REGISTRY = {
+    "acc": None,
+    "acc_norm": None,
+}

 def register_metric(name):

@@ -45,6 +48,7 @@ searching in HF Evaluate library...")
 def register_aggregation(name):
+    # TODO: should we enforce a specific interface to aggregation metrics?
     def decorate(fn):
         assert (
             name not in AGGREGATION_REGISTRY

@@ -155,6 +159,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
 @register_metric("perplexity")
+@register_aggregation("perplexity")
 def perplexity(items):
     return math.exp(-mean(items))
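As a rough usage sketch (not part of this commit; the per-example metric signature is an assumption, since the TODO above notes the interface is not yet pinned down), the two decorators shown here could register a custom metric together with its aggregation:

from lm_eval.api.metrics import register_metric, register_aggregation

@register_metric("exact_match_strict")            # stored in METRIC_REGISTRY under this name
def exact_match_strict(prediction, reference):    # assumed per-example signature
    return float(prediction.strip() == reference.strip())

@register_aggregation("exact_match_strict")       # stored in AGGREGATION_REGISTRY
def exact_match_strict_agg(items):                # reduces per-example scores to one number
    return sum(items) / max(len(items), 1)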
lm_eval/api/model.py (+14, -10)

 import abc
 from typing import Union

 from lm_eval import utils

 MODEL_REGISTRY = {}

-def register_model(name):
-    # TODO: should fairseq/elk be cited for this design pattern?
+def register_model(*names):
+    # either pass a list or a single alias.
+    # function receives them as a tuple of strings

     def decorate(cls):
-        assert (
-            issubclass(cls, LM)
-        ), f"Model '{name}' ({cls.__name__}) must extend LM class"
-
-        assert (
-            name not in MODEL_REGISTRY
-        ), f"Model named '{name}' conflicts with existing model!"
-
-        MODEL_REGISTRY[name] = cls
+        for name in names:
+            assert (
+                issubclass(cls, LM)
+            ), f"Model '{name}' ({cls.__name__}) must extend LM class"
+
+            assert (
+                name not in MODEL_REGISTRY
+            ), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
+
+            MODEL_REGISTRY[name] = cls

         return cls

     return decorate
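A brief, hypothetical usage sketch of the reworked decorator: the class body below is invented, and only the multi-alias call pattern comes from the diff above.

from lm_eval.api.model import LM, MODEL_REGISTRY, register_model

@register_model("my-backend", "my-backend-alias")    # one class, two registry keys
class MyLM(LM):
    ...  # LM's abstract methods would still need real implementations

# registration is the decorator's only effect: both aliases map to the same class
assert MODEL_REGISTRY["my-backend"] is MODEL_REGISTRY["my-backend-alias"]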
lm_eval/api/task.py (+102, -44)

@@ -5,13 +5,15 @@ import re
 import evaluate
 import random
 import itertools
 import functools
 import datasets
 import numpy as np
 from typing import List, Union

-from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY, HIGHER_IS_BETTER_REGISTRY
+from lm_eval.api.metrics import METRIC_REGISTRY, AGGREGATION_REGISTRY
+from lm_eval.api import HIGHER_IS_BETTER_REGISTRY
 from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
 from lm_eval import utils

@@ -36,10 +38,11 @@ class TaskConfig(dict):
     doc_to_text: str = ""
     doc_to_target: str = ""
     # aggregation: dict = None  # TODO: remove, I think these 2 are obsolete w/ current metric_list impl.
     # higher_is_better: dict = None
     num_fewshot: int = 0
     batch_size: int = 1
     repeats: int = 1
     metric_list: str = None
     gold_alias: str = None
+    output_type: str = "greedy_until"

@@ -122,7 +125,8 @@ class Task(abc.ABC):
             filter_pipeline = build_filter_ensemble(name, components)
             self._filters.append(filter_pipeline)

-        self.sampler = samplers.Sampler(self.training_docs(), self, rnd=random.Random())  # TODO: pass the correct docs in here
+        self.sampler = samplers.Sampler(self.fewshot_docs(), self, rnd=random.Random())  # TODO: pass the correct docs in here

     def download(self, data_dir=None, cache_dir=None, download_mode=None):
         """Downloads and returns the task dataset.

@@ -193,6 +197,19 @@ class Task(abc.ABC):
         """
         return []

+    def fewshot_docs(self):
+        """
+        :return: Iterable[obj]
+            A iterable of any object, that doc_to_text can handle
+        """
+        if self.has_training_docs():
+            return self.training_docs()
+        elif self.has_validation_docs():
+            return self.validation_docs()
+        else:
+            # TODO: should we allow this case to occur? / should raise a warning here
+            return self.test_docs()
+
     def _process_doc(self, doc):
         """
         Override this to process (detokenize, strip, replace, etc.) individual

@@ -336,33 +353,33 @@ class Task(abc.ABC):
             labeled_examples = ""
         else:
-            # labeled_examples = self.sampler.get_context(doc, self._config.num_fewshot)
+            labeled_examples = self.sampler.get_context(doc, self._config.num_fewshot)

             # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
-            if self.has_training_docs():
-                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
-            else:
-                if self._fewshot_docs is None:
-                    self._fewshot_docs = list(
-                        self.validation_docs()
-                        if self.has_validation_docs()
-                        else self.test_docs()
-                    )
-
-                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
-
-                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
-                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
-
-            labeled_examples = (
-                "\n\n".join(
-                    [
-                        self.doc_to_text(doc) + self.doc_to_target(doc)
-                        for doc in fewshotex
-                    ]
-                )
-                + "\n\n"
-            )
+            # if self.has_training_docs():
+            #     fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
+            # else:
+            #     if self._fewshot_docs is None:
+            #         self._fewshot_docs = list(
+            #             self.validation_docs()
+            #             if self.has_validation_docs()
+            #             else self.test_docs()
+            #         )
+            #     fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
+            #     # get rid of the doc that's the one we're evaluating, if it's in the fewshot
+            #     fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
+            #     labeled_examples = (
+            #         "\n\n".join(
+            #             [
+            #                 self.doc_to_text(doc) + self.doc_to_target(doc)
+            #                 for doc in fewshotex
+            #             ]
+            #         )
+            #         + "\n\n"
+            #     )

         example = self.doc_to_text(doc)
         return labeled_examples + example

@@ -376,7 +393,7 @@ class Task(abc.ABC):
 class ConfigurableTask(Task):
     VERSION = "2.0"
-    OUTPUT_TYPE = "greedy_until"
+    OUTPUT_TYPE = None

     def __init__(
         self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None

@@ -432,6 +449,8 @@ class ConfigurableTask(Task):
         for name, components in self._config.get("filters", [["none", ["take_first"]]]):
             filter_pipeline = build_filter_ensemble(name, components)
             self._filters.append(filter_pipeline)

+        self.sampler = samplers.Sampler(list(self.fewshot_docs()), self, rnd=random.Random())  # TODO: pass the correct docs in here
+
     def has_training_docs(self):
         if self._config.training_split is not None:

@@ -463,6 +482,13 @@ class ConfigurableTask(Task):
         if self._config.test_split is not None:
             return self.dataset[self._config.test_split]

+    def fewshot_docs(self):
+        if self._config.fewshot_split:
+            return self.dataset[self._config.fewshot_split]
+        else:
+            # TODO: warn user if fewshot split isn't explicitly set
+            return super().fewshot_docs()
+
     def should_decontaminate(self):
         return self._config.should_decontaminate

@@ -497,6 +523,19 @@ class ConfigurableTask(Task):
             arguments = (ctx, self.doc_to_target(doc))
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             arguments = (self.doc_to_target(doc),)
+        elif self.OUTPUT_TYPE == "multiple_choice":
+            import ast
+            return [
+                Instance(
+                    request_type="loglikelihood",
+                    doc=doc,
+                    arguments=(ctx, " {}".format(choice)),
+                    id_=i,
+                    **kwargs,
+                )
+                for i, choice in enumerate(ast.literal_eval(utils.apply_template(self._config.template_aliases + "{{answer_choices}}", doc)))  # we pass the user-defined answer_choices var (in aliases) and echo the result. TODO: any cleaner way to do this?
+            ]
         elif self.OUTPUT_TYPE == "greedy_until":
             arguments = (ctx, self._config.delimiter)

@@ -504,6 +543,7 @@ class ConfigurableTask(Task):
             request_type=self.OUTPUT_TYPE,
             doc=doc,
             arguments=arguments,
+            id_=0,
             **kwargs
         )

@@ -516,6 +556,22 @@ class ConfigurableTask(Task):
             result_dict = {"perplexity": ll, "accuracy": int(is_greedy)}
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             pass
+        elif self.OUTPUT_TYPE == "multiple_choice":
+            lls = [res[0] for res in results]  # only retain loglikelihoods, discard is_greedy
+            gold = int(self.doc_to_target(doc))  # TODO: if `gold` here is an integer/ds label obj, map it to answer_choice
+            # TODO: remove dependence on "gold" and "choices" columns
+            acc = 1.0 if np.argmax(lls) == gold else 0.0
+            completion_len = np.array([float(len(i)) for i in doc["choices"]])
+            acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
+            # TODO: set which normalization metrics should be reported, and calculate them
+            # TODO: add mutual info.
+
+            result_dict = {
+                "acc": acc,
+                "acc_norm": acc_norm,
+            }
         elif self.OUTPUT_TYPE == "greedy_until":
             if self._config.gold_alias is not None:

@@ -531,6 +587,10 @@ class ConfigurableTask(Task):
                 )
                 result_dict[key] = _dict[key]
         else:
+            raise ValueError(
+                f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
+                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until'"
+            )

         return result_dict

@@ -558,11 +618,6 @@ class MultipleChoiceTask(Task):
                 **kwargs,
             )
             for i, choice in enumerate(doc["choices"])]
-        #lls = [
-        #    rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
-        # ]
-        # return lls

     def process_results(self, doc, results):
         results = [res[0] for res in results]  # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?

@@ -668,19 +723,22 @@ class PerplexityTask(Task, abc.ABC):
 TASK_REGISTRY = {}
 ALL_TASKS = []

-def register_task(name):
+def register_task(*names):
+    # either pass a list or a single alias.
+    # function receives them as a tuple of strings

     def decorate(cls):
-        assert (
-            issubclass(cls, Task)
-        ), f"Task '{name}' ({cls.__name__}) must extend Task class"
-
-        assert (
-            name not in TASK_REGISTRY
-        ), f"Task named '{name}' conflicts with existing task!"
-
-        TASK_REGISTRY[name] = cls
-        ALL_TASKS = sorted(list(TASK_REGISTRY))  # TODO: this doesn't seem to import right.
+        for name in names:
+            assert (
+                issubclass(cls, Task)
+            ), f"Task '{name}' ({cls.__name__}) must extend Task class"
+
+            assert (
+                name not in TASK_REGISTRY
+            ), f"Task named '{name}' conflicts with existing task! Please register with a non-conflicting alias instead."
+
+            TASK_REGISTRY[name] = cls
+            ALL_TASKS = sorted(list(TASK_REGISTRY))  # TODO: this doesn't seem to import right.

         return cls

     return decorate
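To make the new multiple_choice scoring concrete, here is a small standalone sketch of how per-choice loglikelihoods become acc and acc_norm. The sample numbers are invented, and the length normalization divides the loglikelihoods themselves, which is the intent of the process_results hunk above.

import numpy as np

# one (loglikelihood, is_greedy) pair per answer choice, as returned by the LM
results = [(-4.1, False), (-2.3, False), (-5.0, False), (-3.7, False)]
choices = ["Paris", "London is big", "Rome", "Berlin"]
gold = 1                                             # index of the correct choice

lls = [res[0] for res in results]                    # keep loglikelihoods, drop is_greedy
acc = 1.0 if np.argmax(lls) == gold else 0.0         # plain argmax accuracy

# length normalization: divide each loglikelihood by the character length of its
# choice so longer answers are not penalized simply for containing more text
completion_len = np.array([float(len(c)) for c in choices])
acc_norm = 1.0 if np.argmax(np.array(lls) / completion_len) == gold else 0.0

print(acc, acc_norm)                                 # 1.0 1.0 for these numbers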
lm_eval/evaluator.py (+2, -1)

@@ -145,7 +145,8 @@ def evaluate(
         # for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
         task.build_all_requests(limit=limit)

         # aggregate Instances by LM method requested to get output.
-        requests[task.OUTPUT_TYPE].extend(task.instances)
+        reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE  # TODO: this is hacky, fix in task.py
+        requests[reqtype].extend(task.instances)

     ### Run LM on inputs, get all outputs ###
     # execute each type of request
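In rough outline, the grouping step now works as sketched below (a hedged approximation, not the harness's exact loop; the defaultdict and the tasks/limit variables are assumed): multiple_choice tasks emit plain loglikelihood Instances, so their requests land in the loglikelihood bucket and are answered by the LM's loglikelihood method.

from collections import defaultdict

requests = defaultdict(list)
for task in tasks:                                   # the selected Task objects
    task.build_all_requests(limit=limit)
    # multiple_choice is scored from per-choice loglikelihoods,
    # so it borrows the "loglikelihood" request type
    reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE
    requests[reqtype].extend(task.instances)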
lm_eval/models/gpt2.py (+1, -1)

@@ -9,7 +9,7 @@ from lm_eval import utils
 from lm_eval.api.model import LM, register_model


-@register_model("hf-causal")
+@register_model("hf-causal", "gpt2")
 class HFLM(LM):
     def __init__(
         self,
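After this change both aliases should resolve to the same class; a tiny hypothetical check, using the MODEL_REGISTRY shown in the model.py hunk above:

import lm_eval.models.gpt2  # noqa: F401  (importing the module runs the @register_model decorator)
from lm_eval.api.model import MODEL_REGISTRY

assert MODEL_REGISTRY["hf-causal"] is MODEL_REGISTRY["gpt2"]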
lm_eval/tasks/arc.yaml (+8, -3)

 dataset_path: ai2_arc
 dataset_name: ARC-Challenge
+output_type: multiple_choice
 training_split: train
 validation_split: validation
 test_split: test
-doc_to_text: "Q: {{question}}\nA:"
-doc_to_target: "{% set answer_choices = doc['choices']['text'] %}{{answer_choices[int(doc['answerKey']) - 1]}}"
+template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}"  # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
+doc_to_text: "Question: {{question}}\nAnswer:"
+doc_to_target: "{{gold}}"
 metric_list:
-  - metric: exact_match
+  - metric: acc
     aggregation: mean
     higher_is_better: true
+  - metric: acc_norm
+    aggregation: mean
+    higher_is_better: true
\ No newline at end of file
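To illustrate how the new template_aliases line is consumed by the multiple_choice path in task.py, here is a rough approximation: a plain Jinja2 render stands in for utils.apply_template, and the doc fields are invented sample data.

import ast
from jinja2 import Template  # stand-in for utils.apply_template

doc = {
    "question": "Which gas do plants absorb from the atmosphere?",
    "choices": {"text": ["Oxygen", "Carbon dioxide", "Nitrogen", "Helium"],
                "label": ["A", "B", "C", "D"]},
    "answerKey": "B",
}

template_aliases = (
    "{% set answer_choices = choices['text'] %}"
    "{% set gold = choices.label.index(answerKey) %}"
)

# The task appends "{{answer_choices}}" and evaluates the echoed list literal,
# yielding one loglikelihood request per answer choice.
rendered = Template(template_aliases + "{{answer_choices}}").render(**doc)
choices = ast.literal_eval(rendered)   # -> ['Oxygen', 'Carbon dioxide', 'Nitrogen', 'Helium']

gold = Template(template_aliases + "{{gold}}").render(**doc)
print(choices, gold)                   # gold renders as '1', the index of answerKey "B"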
lm_eval/tasks/lambada.yaml (+3, -3)

 dataset_path: EleutherAI/lambada_openai
 dataset_name: default
-output_type: "loglikelihood"
+output_type: loglikelihood
 test_split: test
-template_aliases: "{% set hypo = hypothesis %}"
+template_aliases: ""
 doc_to_text: "{{text.split(' ')[:-1]|join(' ')}}"
 doc_to_target: "{{' '+text.split(' ')[-1]}}"
 should_decontaminate: true

@@ -12,5 +12,5 @@ metric_list:
     aggregation: perplexity
     higher_is_better: true
   - metric: accuracy
-    aggregation: perplexity
+    aggregation: mean
     higher_is_better: true
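For reference, a small sketch of what the two aggregations in this YAML compute over per-document results (the numbers are made up; the perplexity line mirrors the math.exp(-mean(...)) definition shown in the metrics.py hunk above):

import math

# per-document (loglikelihood, is_greedy) results for a LAMBADA-style task
results = [(-1.2, True), (-0.8, True), (-3.5, False)]

lls = [ll for ll, _ in results]
greedy_hits = [int(is_greedy) for _, is_greedy in results]

perplexity = math.exp(-sum(lls) / len(lls))       # aggregation: perplexity
accuracy = sum(greedy_hits) / len(greedy_hits)    # aggregation: mean (the fix in this hunk)
print(f"ppl={perplexity:.3f} acc={accuracy:.3f}")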