gaoqiong / lm-evaluation-harness · Commits

Commit 487f7811, authored May 02, 2023 by haileyschoelkopf
Parent: e7f49cca

prelim. multiple choice support

Showing 9 changed files with 140 additions and 63 deletions (+140, -63)
lm_eval/api/__init__.py       +3    -0
lm_eval/api/instance.py       +1    -0
lm_eval/api/metrics.py        +6    -1
lm_eval/api/model.py          +14   -10
lm_eval/api/task.py           +102  -44
lm_eval/evaluator.py          +2    -1
lm_eval/models/gpt2.py        +1    -1
lm_eval/tasks/arc.yaml        +8    -3
lm_eval/tasks/lambada.yaml    +3    -3
lm_eval/api/__init__.py (+3, -0)

@@ -22,4 +22,7 @@ HIGHER_IS_BETTER_REGISTRY = {
     "bleu": True,
     "chrf": True,
     "ter": False,
+    "acc": True,
+    "acc_norm": True,
 }
\ No newline at end of file
lm_eval/api/instance.py (+1, -0)

@@ -11,6 +11,7 @@ class Instance:
     resps: list = field(default_factory=list)
     filtered_resps: dict = field(default_factory=dict)

+    # initialized after init
     task_name: str = None
     doc_id: str = None
     repeats: str = None
lm_eval/api/metrics.py (+6, -1)

@@ -10,7 +10,10 @@ import evaluate
 AGGREGATION_REGISTRY = {}
-METRIC_REGISTRY = {}
+METRIC_REGISTRY = {
+    "acc": None,
+    "acc_norm": None,
+}

 def register_metric(name):

@@ -45,6 +48,7 @@ searching in HF Evaluate library...")
 def register_aggregation(name):
+    # TODO: should we enforce a specific interface to aggregation metrics?
     def decorate(fn):
         assert (
             name not in AGGREGATION_REGISTRY

@@ -155,6 +159,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
 @register_metric("perplexity")
+@register_aggregation("perplexity")
 def perplexity(items):
     return math.exp(-mean(items))
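For orientation, the registries above are consumed through the two decorators defined in this file. Below is a minimal sketch of registering a new metric/aggregation pair under one name, mirroring the "perplexity" pattern in the last hunk; the metric name and function are hypothetical, and the sketch assumes the decorators simply store the callable under the given key (the new "acc"/"acc_norm" entries are pre-seeded with None, presumably so task configs can reference them before a callable is attached).

# Hypothetical sketch, not part of this commit: register a new aggregation-style
# metric the same way "perplexity" is registered above.
from lm_eval.api.metrics import register_metric, register_aggregation, mean

@register_metric("mean_loglikelihood")        # hypothetical metric name
@register_aggregation("mean_loglikelihood")   # same callable doubles as its aggregation
def mean_loglikelihood(items):
    # items: per-document loglikelihoods collected by Task.process_results
    return mean(items)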
lm_eval/api/model.py (+14, -10)

 import abc
+from typing import Union

 from lm_eval import utils

 MODEL_REGISTRY = {}

-def register_model(name):
+def register_model(*names):
+    # TODO: should fairseq/elk be cited for this design pattern?
+    # either pass a list or a single alias.
+    # function receives them as a tuple of strings
     def decorate(cls):
-        assert (
-            issubclass(cls, LM)
-        ), f"Model '{name}' ({cls.__name__}) must extend LM class"
-        assert (
-            name not in MODEL_REGISTRY
-        ), f"Model named '{name}' conflicts with existing model!"
-        MODEL_REGISTRY[name] = cls
+        for name in names:
+            assert (
+                issubclass(cls, LM)
+            ), f"Model '{name}' ({cls.__name__}) must extend LM class"
+            assert (
+                name not in MODEL_REGISTRY
+            ), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
+            MODEL_REGISTRY[name] = cls
         return cls
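A short usage sketch of the variadic decorator follows. The DummyLM class and its aliases are hypothetical (the only alias this commit actually adds is "gpt2" for HFLM in lm_eval/models/gpt2.py), and the sketch assumes the chosen aliases are not already present in MODEL_REGISTRY.

# Hypothetical sketch: one LM subclass registered under several aliases.
from lm_eval.api.model import LM, MODEL_REGISTRY, register_model

@register_model("dummy", "dummy-lm")  # both aliases map to the same class
class DummyLM(LM):
    pass

# Each alias resolves to the same class object; a duplicate alias would trip the
# "conflicts with existing model" assertion instead.
assert MODEL_REGISTRY["dummy"] is MODEL_REGISTRY["dummy-lm"] is DummyLM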
lm_eval/api/task.py (+102, -44)
@@ -5,13 +5,15 @@ import re
 import evaluate
 import random
 import itertools
+import functools
 import datasets
 import numpy as np

 from typing import List, Union

-from lm_eval.api import METRIC_REGISTRY, AGGREGATION_REGISTRY, HIGHER_IS_BETTER_REGISTRY
+from lm_eval.api.metrics import METRIC_REGISTRY, AGGREGATION_REGISTRY
+from lm_eval.api import HIGHER_IS_BETTER_REGISTRY
 from lm_eval.api.instance import Instance
 from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
 from lm_eval import utils
@@ -36,10 +38,11 @@ class TaskConfig(dict):
     doc_to_text: str = ""
     doc_to_target: str = ""
+    # aggregation: dict = None  # TODO: remove, I think these 2 are obsolete w/ current metric_list impl.
+    # higher_is_better: dict = None
     num_fewshot: int = 0
     batch_size: int = 1
-    repeats: int = 1
     metric_list: str = None
     gold_alias: str = None
     output_type: str = "greedy_until"
@@ -122,7 +125,8 @@ class Task(abc.ABC):
             filter_pipeline = build_filter_ensemble(name, components)
             self._filters.append(filter_pipeline)

-        self.sampler = samplers.Sampler(self.training_docs(), self, rnd=random.Random())  # TODO: pass the correct docs in here
+        self.sampler = samplers.Sampler(self.fewshot_docs(), self, rnd=random.Random())  # TODO: pass the correct docs in here

     def download(self, data_dir=None, cache_dir=None, download_mode=None):
         """Downloads and returns the task dataset.
@@ -193,6 +197,19 @@ class Task(abc.ABC):
         """
         return []

+    def fewshot_docs(self):
+        """
+        :return: Iterable[obj]
+            A iterable of any object, that doc_to_text can handle
+        """
+        if self.has_training_docs():
+            return self.training_docs()
+        elif self.has_validation_docs():
+            return self.validation_docs()
+        else:
+            # TODO: should we allow this case to occur? / should raise a warning here
+            return self.test_docs()
+
     def _process_doc(self, doc):
         """
         Override this to process (detokenize, strip, replace, etc.) individual
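The fallback order encoded in the new Task.fewshot_docs() can be summarized with a small standalone sketch (illustration only, not harness code): few-shot examples come from the training split when one exists, otherwise from the validation split, and only as a last resort from the test split (hence the TODO about a warning).

# Illustration only: the split-selection order implemented by Task.fewshot_docs().
def fewshot_pool(has_train, has_val, train=(), val=(), test=()):
    """Return the document pool that few-shot examples should be drawn from."""
    if has_train:
        return train
    elif has_val:
        return val
    # last resort: the test split itself (the TODO above suggests warning here)
    return test

# A test-only task ends up drawing its few-shot examples from its own test split:
print(fewshot_pool(has_train=False, has_val=False, test=[{"q": "2+2", "a": "4"}]))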
@@ -336,33 +353,33 @@ class Task(abc.ABC):
             labeled_examples = ""
         else:
-            # labeled_examples = self.sampler.get_context(doc, self._config.num_fewshot)
+            labeled_examples = self.sampler.get_context(doc, self._config.num_fewshot)

             # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
-            if self.has_training_docs():
-                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
-            else:
-                if self._fewshot_docs is None:
-                    self._fewshot_docs = list(
-                        self.validation_docs()
-                        if self.has_validation_docs()
-                        else self.test_docs()
-                    )
-                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
-                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
-                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
-                labeled_examples = (
-                    "\n\n".join(
-                        [
-                            self.doc_to_text(doc) + self.doc_to_target(doc)
-                            for doc in fewshotex
-                        ]
-                    )
-                    + "\n\n"
-                )
+            # if self.has_training_docs():
+            #     fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
+            # else:
+            #     if self._fewshot_docs is None:
+            #         self._fewshot_docs = list(
+            #             self.validation_docs()
+            #             if self.has_validation_docs()
+            #             else self.test_docs()
+            #         )
+            #     fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
+            #     # get rid of the doc that's the one we're evaluating, if it's in the fewshot
+            #     fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
+            #     labeled_examples = (
+            #         "\n\n".join(
+            #             [
+            #                 self.doc_to_text(doc) + self.doc_to_target(doc)
+            #                 for doc in fewshotex
+            #             ]
+            #         )
+            #         + "\n\n"
+            #     )

         example = self.doc_to_text(doc)
         return labeled_examples + example
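The commented-out block documents what the new sampler-based call has to reproduce: take up to num_fewshot documents (excluding the document being evaluated), concatenate prompt plus target for each, join them with blank lines, and append the current document's prompt. A self-contained sketch of that string construction, with made-up documents and the sampler internals left out, looks like this:

# Sketch only: the context string the sampler-based path is expected to produce,
# equivalent in spirit to the commented-out manual logic above.
def build_fewshot_context(doc_to_text, doc_to_target, fewshot_docs, doc, num_fewshot):
    if num_fewshot == 0:
        labeled_examples = ""
    else:
        shots = [d for d in fewshot_docs if d != doc][:num_fewshot]
        labeled_examples = (
            "\n\n".join(doc_to_text(d) + doc_to_target(d) for d in shots) + "\n\n"
        )
    return labeled_examples + doc_to_text(doc)

docs = [{"q": "1+1", "a": " 2"}, {"q": "2+2", "a": " 4"}, {"q": "3+3", "a": " 6"}]
ctx = build_fewshot_context(
    doc_to_text=lambda d: f"Q: {d['q']}\nA:",
    doc_to_target=lambda d: d["a"],
    fewshot_docs=docs,
    doc=docs[2],      # the evaluated doc is filtered out of its own context
    num_fewshot=2,
)
print(ctx)  # "Q: 1+1\nA: 2\n\nQ: 2+2\nA: 4\n\nQ: 3+3\nA:"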
@@ -376,7 +393,7 @@ class Task(abc.ABC):
 class ConfigurableTask(Task):
     VERSION = "2.0"
-    OUTPUT_TYPE = "greedy_until"
+    OUTPUT_TYPE = None

     def __init__(
         self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
@@ -433,6 +450,8 @@ class ConfigurableTask(Task):
             filter_pipeline = build_filter_ensemble(name, components)
             self._filters.append(filter_pipeline)

+        self.sampler = samplers.Sampler(list(self.fewshot_docs()), self, rnd=random.Random())  # TODO: pass the correct docs in here
+
     def has_training_docs(self):
         if self._config.training_split is not None:
             return True
@@ -463,6 +482,13 @@ class ConfigurableTask(Task):
         if self._config.test_split is not None:
             return self.dataset[self._config.test_split]

+    def fewshot_docs(self):
+        if self._config.fewshot_split:
+            return self.dataset[self._config.fewshot_split]
+        else:
+            # TODO: warn user if fewshot split isn't explicitly set
+            return super().fewshot_docs()
+
     def should_decontaminate(self):
         return self._config.should_decontaminate
@@ -497,6 +523,19 @@ class ConfigurableTask(Task):
             arguments = (ctx, self.doc_to_target(doc))
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             arguments = (self.doc_to_target(doc),)
+        elif self.OUTPUT_TYPE == "multiple_choice":
+            import ast
+            return [
+                Instance(
+                    request_type="loglikelihood",
+                    doc=doc,
+                    arguments=(ctx, " {}".format(choice)),
+                    id_=i,
+                    **kwargs,
+                )
+                for i, choice in enumerate(ast.literal_eval(utils.apply_template(self._config.template_aliases + "{{answer_choices}}", doc)))  # we pass the user-defined answer_choices var (in aliases) and echo the result. TODO: any cleaner way to do this?
+            ]
         elif self.OUTPUT_TYPE == "greedy_until":
             arguments = (ctx, self._config.delimiter)
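The multiple_choice branch leans on a template trick: the task's template_aliases define an answer_choices variable, the string "{{answer_choices}}" is appended so that rendering echoes the list back, and ast.literal_eval turns the echoed repr into a Python list; one loglikelihood Instance is then built per choice. A rough sketch of that round trip, using jinja2 directly rather than the harness's utils.apply_template (an assumption made for the sketch) and an ARC-style document invented for illustration:

# Sketch (jinja2 stands in for utils.apply_template; the doc is made up).
import ast
from jinja2 import Template

doc = {
    "question": "Which gas do plants absorb?",
    "choices": {"text": ["Oxygen", "Carbon dioxide"], "label": ["A", "B"]},
    "answerKey": "B",
}
aliases = "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}"

# Appending "{{answer_choices}}" echoes the alias back as a Python-list repr...
rendered = Template(aliases + "{{answer_choices}}").render(**doc)
choices = ast.literal_eval(rendered)          # "['Oxygen', 'Carbon dioxide']" -> list

# ...and each entry becomes one loglikelihood request of the form (ctx, " <choice>").
requests = [(i, " {}".format(choice)) for i, choice in enumerate(choices)]
print(requests)                               # [(0, ' Oxygen'), (1, ' Carbon dioxide')]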
@@ -504,6 +543,7 @@ class ConfigurableTask(Task):
             request_type=self.OUTPUT_TYPE,
             doc=doc,
             arguments=arguments,
+            id_=0,
             **kwargs
         )
@@ -516,6 +556,22 @@ class ConfigurableTask(Task):
             result_dict = {"perplexity": ll, "accuracy": int(is_greedy)}
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             pass
+        elif self.OUTPUT_TYPE == "multiple_choice":
+            lls = [res[0] for res in results]  # only retain loglikelihoods, discard is_greedy
+            gold = int(self.doc_to_target(doc))  # TODO: if `gold` here is an integer/ds label obj, map it to answer_choice
+            # TODO: remove dependence on "gold" and "choices" columns
+            acc = 1.0 if np.argmax(lls) == gold else 0.0
+            completion_len = np.array([float(len(i)) for i in doc["choices"]])
+            acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
+            # TODO: set which normalization metrics should be reported, and calculate them
+            # TODO: add mutual info.
+            result_dict = {
+                "acc": acc,
+                "acc_norm": acc_norm,
+            }
         elif self.OUTPUT_TYPE == "greedy_until":
             if self._config.gold_alias is not None:
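A worked sketch of the new acc / acc_norm arithmetic, with made-up numbers: acc takes the raw argmax over the per-choice loglikelihoods, while acc_norm divides each loglikelihood by the choice's character length first, which can flip the prediction toward longer answers.

# Worked example (values invented) of the acc / acc_norm logic added above.
import numpy as np

choices = ["Oxygen", "Carbon dioxide", "Neon"]   # stands in for doc["choices"]
lls = np.array([-5.0, -6.0, -8.0])               # loglikelihood of each choice
gold = 1                                         # index of the correct choice

completion_len = np.array([float(len(c)) for c in choices])
acc = 1.0 if np.argmax(lls) == gold else 0.0                         # raw argmax -> choice 0
acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0   # normalized -> choice 1
print(acc, acc_norm)  # 0.0 1.0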
@@ -531,6 +587,10 @@ class ConfigurableTask(Task):
                 )
                 result_dict[key] = _dict[key]
+        else:
+            raise ValueError(
+                f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
+                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until'"
+            )

         return result_dict
@@ -558,11 +618,6 @@ class MultipleChoiceTask(Task):
                 **kwargs,
             )
             for i, choice in enumerate(doc["choices"])]
-        #lls = [
-        #    rf.loglikelihood(ctx, " {}".format(choice))[0] for choice in doc["choices"]
-        #    ]
-        # return lls

     def process_results(self, doc, results):
         results = [res[0] for res in results]  # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
@@ -668,16 +723,19 @@ class PerplexityTask(Task, abc.ABC):
 TASK_REGISTRY = {}
 ALL_TASKS = []

-def register_task(name):
+def register_task(*names):
+    # either pass a list or a single alias.
+    # function receives them as a tuple of strings
     def decorate(cls):
-        assert (
-            issubclass(cls, Task)
-        ), f"Task '{name}' ({cls.__name__}) must extend Task class"
-        assert (
-            name not in TASK_REGISTRY
-        ), f"Task named '{name}' conflicts with existing task!"
-        TASK_REGISTRY[name] = cls
+        for name in names:
+            assert (
+                issubclass(cls, Task)
+            ), f"Task '{name}' ({cls.__name__}) must extend Task class"
+            assert (
+                name not in TASK_REGISTRY
+            ), f"Task named '{name}' conflicts with existing task! Please register with a non-conflicting alias instead."
+            TASK_REGISTRY[name] = cls

         ALL_TASKS = sorted(list(TASK_REGISTRY))  # TODO: this doesn't seem to import right.
lm_eval/evaluator.py (+2, -1)

@@ -145,7 +145,8 @@ def evaluate(
         # for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
         task.build_all_requests(limit=limit)
         # aggregate Instances by LM method requested to get output.
-        requests[task.OUTPUT_TYPE].extend(task.instances)
+        reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE  # TODO: this is hacky, fix in task.py
+        requests[reqtype].extend(task.instances)

     ### Run LM on inputs, get all outputs ###
     # execute each type of request
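The routing change above can be pictured with a small, self-contained sketch (the task objects below are hypothetical stand-ins): a multiple_choice task produces plain loglikelihood Instances, so its requests are queued under the LM's loglikelihood bucket rather than under a non-existent multiple_choice method.

# Sketch with hypothetical stand-in tasks; mirrors the reqtype remapping above.
from collections import defaultdict

class FakeTask:
    def __init__(self, output_type, instances):
        self.OUTPUT_TYPE = output_type
        self.instances = instances

tasks = [
    FakeTask("multiple_choice", ["arc request 0", "arc request 1"]),
    FakeTask("greedy_until", ["generation request 0"]),
]

requests = defaultdict(list)
for task in tasks:
    reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE
    requests[reqtype].extend(task.instances)

print(dict(requests))
# {'loglikelihood': ['arc request 0', 'arc request 1'], 'greedy_until': ['generation request 0']}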
lm_eval/models/gpt2.py (+1, -1)

@@ -9,7 +9,7 @@ from lm_eval import utils
 from lm_eval.api.model import LM, register_model

-@register_model("hf-causal")
+@register_model("hf-causal", "gpt2")
 class HFLM(LM):
     def __init__(
         self,
lm_eval/tasks/arc.yaml (+8, -3)

 dataset_path: ai2_arc
 dataset_name: ARC-Challenge
+output_type: multiple_choice
 training_split: train
 validation_split: validation
 test_split: test
-doc_to_text: "Q: {{question}}\nA:"
-doc_to_target: "{% set answer_choices = doc['choices']['text'] %}{{answer_choices[int(doc['answerKey']) - 1]}}"
+template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}"  # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
+doc_to_text: "Question: {{question}}\nAnswer:"
+doc_to_target: "{{gold}}"
 metric_list:
-  - metric: exact_match
+  - metric: acc
+    aggregation: mean
+    higher_is_better: true
+  - metric: acc_norm
     aggregation: mean
     higher_is_better: true
\ No newline at end of file
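How the new arc.yaml fields fit together, sketched with jinja2 in place of the harness's template machinery (an assumption), reusing the invented ARC-style document from the earlier sketch: template_aliases defines answer_choices and a gold index derived from answerKey, doc_to_text builds the prompt, and doc_to_target is just "{{gold}}", which process_results converts with int() to select the winning choice index.

# Sketch only (jinja2 used directly; assumes the aliases are prepended when
# doc_to_target is rendered, as they are for the answer_choices echo).
from jinja2 import Template

doc = {
    "question": "Which gas do plants absorb?",
    "choices": {"text": ["Oxygen", "Carbon dioxide"], "label": ["A", "B"]},
    "answerKey": "B",
}
aliases = "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}"

prompt = Template("Question: {{question}}\nAnswer:").render(**doc)
target = Template(aliases + "{{gold}}").render(**doc)

print(prompt)                                # Question: Which gas do plants absorb?\nAnswer:
print(target)                                # "1" -> int("1") is the gold choice index
print(doc["choices"]["text"][int(target)])   # Carbon dioxide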
lm_eval/tasks/lambada.yaml (+3, -3)

 dataset_path: EleutherAI/lambada_openai
 dataset_name: default
-output_type: "loglikelihood"
+output_type: loglikelihood
 test_split: test
-template_aliases: "{% set hypo = hypothesis %}"
+template_aliases: ""
 doc_to_text: "{{text.split(' ')[:-1]|join(' ')}}"
 doc_to_target: "{{' '+text.split(' ')[-1]}}"
 should_decontaminate: true

@@ -12,5 +12,5 @@ metric_list:
     aggregation: perplexity
     higher_is_better: true
   - metric: accuracy
-    aggregation: perplexity
+    aggregation: mean
     higher_is_better: true