gaoqiong / lm-evaluation-harness · Commits

Commit b4ad893c
Authored Apr 25, 2022 by ken
Merge master
Parents: 8c83a821, 20820c3c

Changes: 35
Showing 20 changed files with 167 additions and 20 deletions (+167 -20)
Files changed on this page:

.github/workflows/python-app.yml                +1   -1
docs/task_guide.md                              +1   -1
lm_eval/base.py                                 +165 -18
lm_eval/datasets/arithmetic/__init__.py         +0   -0
lm_eval/datasets/asdiv/__init__.py              +0   -0
lm_eval/datasets/coqa/__init__.py               +0   -0
lm_eval/datasets/drop/__init__.py               +0   -0
lm_eval/datasets/gsm8k/__init__.py              +0   -0
lm_eval/datasets/headqa/__init__.py             +0   -0
lm_eval/datasets/hendrycks_ethics/__init__.py   +0   -0
lm_eval/datasets/hendrycks_math/__init__.py     +0   -0
lm_eval/datasets/lambada/__init__.py            +0   -0
lm_eval/datasets/logiqa/__init__.py             +0   -0
lm_eval/datasets/mutual/__init__.py             +0   -0
lm_eval/datasets/pile/__init__.py               +0   -0
lm_eval/datasets/quac/__init__.py               +0   -0
lm_eval/datasets/sat_analogies/__init__.py      +0   -0
lm_eval/datasets/triviaqa/__init__.py           +0   -0
lm_eval/datasets/truthfulqa/__init__.py         +0   -0
lm_eval/datasets/unscramble/__init__.py         +0   -0
.github/workflows/python-app.yml

```diff
@@ -32,7 +32,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install flake8 pytest pytest-cov
-        pip install -e .
+        pip install -e .[dev]
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
     - name: Lint with flake8
       run: |
```
docs/task_guide.md

````diff
@@ -11,7 +11,7 @@ If you haven't already, go ahead and fork the main repo, clone it, create a bran
 git clone https://github.com/<YOUR-USERNAME>/lm-evaluation-harness.git
 cd lm-evaluation-harness
 git checkout -b <task-name>
-pip install -r requirements.txt
+pip install -e ".[dev]"
 ```
 
 ## Creating Your Task File
````
lm_eval/base.py

```diff
@@ -121,6 +121,11 @@ class LM(abc.ABC):
 class BaseLM(LM):
 
+    @property
+    @abstractmethod
+    def eot_token(self):
+        pass
+
     @property
     @abstractmethod
     def eot_token_id(self):
```
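For orientation, a minimal sketch (not part of this commit) of how a concrete backend could satisfy the new abstract `eot_token` property alongside the existing `eot_token_id`. The `ToyTokenizer` class and its values are invented placeholders.

```python
from abc import ABC, abstractmethod


class ToyTokenizer:
    # Invented stand-in for a real tokenizer; values chosen for illustration only.
    eos_token = "<|endoftext|>"
    eos_token_id = 50256


class ToyBaseLM(ABC):
    @property
    @abstractmethod
    def eot_token(self):
        pass

    @property
    @abstractmethod
    def eot_token_id(self):
        pass


class ToyLM(ToyBaseLM):
    def __init__(self):
        self.tokenizer = ToyTokenizer()

    @property
    def eot_token(self):
        # String form of the end-of-text marker, usable as a default stop sequence.
        return self.tokenizer.eos_token

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id


print(ToyLM().eot_token)     # <|endoftext|>
print(ToyLM().eot_token_id)  # 50256
```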
```diff
@@ -354,8 +359,15 @@ class BaseLM(LM):
             isinstance(max_generation_length, int) or max_generation_length is None
         )
 
-        until = [stopping_criteria]
+        if stopping_criteria is None:
+            until = [self.eot_token]
+        else:
+            until = [stopping_criteria]
+
         primary_until = self.tok_encode(until[0])
+        if len(primary_until) == 0:
+            primary_until = torch.tensor([self.eot_token_id])
+
         context_enc = torch.tensor(
             [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
         ).to(self.device)
```
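The change above can be read as: when no stopping criterion is supplied, fall back to the model's end-of-text token, and if that string tokenizes to nothing, fall back to the raw token id. A standalone sketch of that logic, with an invented `toy_encode` standing in for the harness's `tok_encode`:

```python
import torch


def resolve_stop_tokens(stopping_criteria, eot_token, eot_token_id, tok_encode):
    # Mirror of the fallback added above, pulled out as a free function.
    if stopping_criteria is None:
        until = [eot_token]
    else:
        until = [stopping_criteria]

    primary_until = tok_encode(until[0])
    if len(primary_until) == 0:
        # Some tokenizers encode the end-of-text string to nothing; use its id instead.
        primary_until = torch.tensor([eot_token_id])
    return until, primary_until


# Toy encoder that drops the special end-of-text string entirely.
toy_encode = lambda s: [] if s == "<|endoftext|>" else [ord(c) for c in s]
print(resolve_stop_tokens(None, "<|endoftext|>", 50256, toy_encode))
```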
```diff
@@ -633,14 +645,18 @@ class Task(abc.ABC):
             # get rid of the doc that's the one we're evaluating, if it's in the fewshot
             fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
 
+            # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
+            # for justification of this separator.
+            example_separator = "\n###\n"
+
             labeled_examples = (
-                "\n\n".join(
+                example_separator.join(
                     [
                         self.doc_to_text(doc) + self.doc_to_target(doc)
                         for doc in fewshotex
                     ]
                 )
-                + "\n\n"
+                + example_separator
            )
 
         example = self.doc_to_text(doc)
```
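The effect of the new `"\n###\n"` separator is easiest to see on toy data; the questions below are invented for illustration.

```python
# Few-shot examples are joined by the separator, and the separator is appended
# once more before the example being evaluated.
fewshot = ["Q: 2+2?\nA: 4", "Q: 3+3?\nA: 6"]
query = "Q: 4+4?\nA:"

example_separator = "\n###\n"
ctx = example_separator.join(fewshot) + example_separator + query
print(ctx)
# Q: 2+2?
# A: 4
# ###
# Q: 3+3?
# A: 6
# ###
# Q: 4+4?
# A:
```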
```diff
@@ -654,11 +670,21 @@ class PromptSourceTask(Task):
     *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
     """
 
-    CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
+    CONFIGURED_RANKED_CHOICE_PS_METRICS = set(["Accuracy"])
+    CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE"])
 
-    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
+    SPLIT = None
+
+    def __init__(
+        self,
+        data_dir=None,
+        cache_dir=None,
+        download_mode=None,
+        prompt=None,
+        save_examples=True,
+    ):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt
+        self.save_examples = save_examples
 
     def stopping_criteria(self) -> Optional[str]:
         """Denote where the generation should end.
```
```diff
@@ -752,12 +778,11 @@ class PromptSourceTask(Task):
 
             for metric in self.prompt.metadata.metrics:
                 assert (
-                    metric in self.CONFIGURED_PS_METRICS
+                    metric in self.CONFIGURED_RANKED_CHOICE_PS_METRICS
                 ), "Unexpected metric. Add it, or use a task-specific solution."
                 if metric == "Accuracy":
                     out["acc"] = pred == target
                 # TODO: Add metrics here.
-            return out
         else:
             # If not, then this is a generation prompt.
             # NOTE: In the future, target will be a list of strings.
```
```diff
@@ -765,11 +790,11 @@ class PromptSourceTask(Task):
             out = {}
             for metric in self.prompt.metadata.metrics:
                 assert (
-                    metric in self.CONFIGURED_PS_METRICS
+                    metric in self.CONFIGURED_GENERATION_PS_METRICS
                 ), "Unexpected metric. Add it, or use a task-specific solution."
                 if metric == "BLEU":
                     out["bleu"] = (target, pred)
-                if metric == "ROUGE":
+                elif metric == "ROUGE":
                     # TODO: This computes all rouge sub-metrics. Find a generic
                     # way to handle user specified rouge sub-metrics to avoid extra
                     # compute.
```
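The validation pattern in these two hunks is the same, just against different sets: every metric a prompt declares must be in the configured set for its prompt type, otherwise the assertion fires. A minimal sketch with illustrative values:

```python
CONFIGURED_GENERATION_PS_METRICS = set(["BLEU", "ROUGE"])


def check_metrics(prompt_metrics, configured):
    # Same assertion message as in the diff above.
    for metric in prompt_metrics:
        assert (
            metric in configured
        ), "Unexpected metric. Add it, or use a task-specific solution."


check_metrics(["BLEU", "ROUGE"], CONFIGURED_GENERATION_PS_METRICS)  # passes
# check_metrics(["METEOR"], CONFIGURED_GENERATION_PS_METRICS)       # would raise AssertionError
```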
```diff
@@ -778,15 +803,21 @@ class PromptSourceTask(Task):
                     rouge_scores = utils.flatten(rouge_scores)
                     # Merge all the rouge-type scores into the `out` dict.
                     out = {**out, **rouge_scores}
-        print(out)
+
+        # TODO: Wrap process results s.t. override impl do not
+        # override the save examples.
+        if self.save_examples:
+            example = {
+                "pred": pred,
+                "target": target,
+                "answer_choices_list": answer_choices_list,
+            }
+            return out, example
         return out
 
     def higher_is_better(self):
         out = {}
         for metric in self.prompt.metadata.metrics:
-            assert (
-                metric in self.CONFIGURED_PS_METRICS
-            ), "Unexpected metric. Add it, or use a task-specific solution."
             if metric == "Accuracy":
                 out["acc"] = True
             if metric == "BLEU":
```
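The hunk above changes the return contract of `process_results`: with `save_examples` enabled, each document now yields a (metrics, example) pair instead of a bare metrics dict. A hedged sketch of a caller that tolerates both shapes; the function and the sample data are invented for illustration.

```python
def handle_result(result):
    # Accept either the new (out, example) pair or the legacy bare dict.
    if isinstance(result, tuple):
        out, example = result
    else:
        out, example = result, None
    return out, example


metrics, example = handle_result(
    (
        {"bleu": ("the target", "the prediction")},
        {"pred": "the prediction", "target": "the target", "answer_choices_list": None},
    )
)
print(metrics["bleu"])
print(example["pred"])
```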
```diff
@@ -813,9 +844,6 @@ class PromptSourceTask(Task):
     def aggregation(self):
         out = {}
         for metric in self.prompt.metadata.metrics:
-            assert (
-                metric in self.CONFIGURED_PS_METRICS
-            ), "Unexpected metric. Add it, or use a task-specific solution."
             if metric == "Accuracy":
                 out["acc"] = mean
             if metric == "BLEU":
```
```diff
@@ -839,6 +867,125 @@ class PromptSourceTask(Task):
                 out["rougeLsum_fmeasure"] = mean
         return out
 
+    def fewshot_examples(self, k, rnd):
+        if self._training_docs is None:
+            self._training_docs = list(self.training_docs())
+        return self._get_fewshot_examples(self._training_docs, k, rnd)
+
+    def _get_fewshot_examples(self, docs, k, rnd):
+        fewshot_idx = rnd.sample(list(np.arange(len(docs))), k)
+        return [docs[idx] for idx in fewshot_idx], [int(idx) for idx in fewshot_idx]
+
+    @utils.positional_deprecated
+    def fewshot_context(
+        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
+    ):
+        """Returns a fewshot context string that is made up of a prepended description
+        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
+
+        :param doc: str
+            The document as returned from training_docs, validation_docs, or test_docs.
+        :param num_fewshot: int
+            The number of fewshot examples to provide in the returned context string.
+        :param provide_description: bool
+            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
+        :param rnd: random.Random
+            The pseudo-random number generator used to randomly sample examples.
+            WARNING: This is currently a required arg although it's optionalized with a default `None`.
+        :param description: str
+            The task's description that will be prepended to the fewshot examples.
+        :returns: str
+            The fewshot context.
+        """
+        assert (
+            rnd is not None
+        ), "A `random.Random` generator argument must be provided to `rnd`"
+        assert not provide_description, (
+            "The `provide_description` arg will be removed in future versions. To prepend "
+            "a custom description to the context, supply the corresponding string via the "
+            "`description` arg."
+        )
+        if provide_description is not None:
+            # nudge people to not specify it at all
+            print(
+                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
+            )
+
+        description = description + "\n\n" if description else ""
+
+        if num_fewshot == 0:
+            labeled_examples = ""
+            fewshotex, fewshotidx, fewshotsource = [], [], None
+        else:
+            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
+            if self.has_training_docs():
+                fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
+                fewshotsource = "train"
+            else:
+                if self._fewshot_docs is None:
+                    self._fewshot_docs = list(
+                        self.validation_docs()
+                        if self.has_validation_docs()
+                        else self.test_docs()
+                    )
+                if self.has_validation_docs():
+                    fewshotsource = "val"
+                elif self.test_docs():
+                    fewshotsource = "test"
+
+                fewshotex, fewshotidx = self._get_fewshot_examples(
+                    self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
+                )
+                fewshotex, fewshotidx = [
+                    (shot, idx)
+                    for shot, idx in zip(fewshotex, fewshotidx)
+                    if shot != doc
+                ]
+                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
+                fewshotex, fewshotidx = (
+                    fewshotex[:num_fewshot],
+                    fewshotidx[:num_fewshot],
+                )
+
+            # See Webson & Pavlick (2022) https://arxiv.org/pdf/2109.01247.pdf
+            # for justification of this separator.
+            example_separator = "\n###\n"
+            labeled_examples = (
+                example_separator.join(
+                    [
+                        self.doc_to_text(doc) + self.doc_to_target(doc)
+                        for doc in fewshotex
+                    ]
+                )
+                + example_separator
+            )
+
+        example = self.doc_to_text(doc)
+        ctx = description + labeled_examples + example
+        return (
+            ctx,
+            {
+                "fewshot_idx": fewshotidx,
+                "fewshot_source": fewshotsource,
+                "fewshot_num": num_fewshot,
+                "ctx": ctx,
+            },
+        )
+
+    def get_logging_info(self):
+        return {
+            "fixed_answer_choice_list": self.prompt.get_fixed_answer_choices_list(),
+            "dataset_path": self.DATASET_PATH,
+            "dataset_name": self.DATASET_NAME,
+            "subset": self.SPLIT,
+            "prompt_name": self.prompt.get_name(),
+            "prompt_id": self.prompt.get_id(),
+            "prompt_jinja": self.prompt.jinja,
+            "prompt_original_task": self.prompt.metadata.original_task,
+            # Placeholder for comment in post-processing.
+            "comment": "",
+        }
+
 
 class MultipleChoiceTask(Task):
     def doc_to_target(self, doc):
```
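A toy walk-through (invented data, no lm_eval imports) of the (ctx, info) pair that the new `fewshot_context` returns: the assembled context string plus a dict recording which examples were sampled and from which split.

```python
import random

docs = [
    {"q": "2+2?", "a": "4"},
    {"q": "3+3?", "a": "6"},
    {"q": "5+5?", "a": "10"},
]
doc_to_text = lambda d: f"Question: {d['q']}\nAnswer:"
doc_to_target = lambda d: f" {d['a']}"

rnd = random.Random(1234)
num_fewshot = 2
idx = rnd.sample(range(len(docs) - 1), num_fewshot)  # sample shots from a "train" pool
eval_doc = docs[-1]

example_separator = "\n###\n"
labeled = (
    example_separator.join(doc_to_text(docs[i]) + doc_to_target(docs[i]) for i in idx)
    + example_separator
)

ctx = labeled + doc_to_text(eval_doc)
info = {
    "fewshot_idx": idx,
    "fewshot_source": "train",
    "fewshot_num": num_fewshot,
    "ctx": ctx,
}
print(ctx)
print(info["fewshot_idx"], info["fewshot_source"])
```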
New empty package marker files (0 → 100644):

lm_eval/datasets/arithmetic/__init__.py
lm_eval/datasets/asdiv/__init__.py
lm_eval/datasets/coqa/__init__.py
lm_eval/datasets/drop/__init__.py
lm_eval/datasets/gsm8k/__init__.py
lm_eval/datasets/headqa/__init__.py
lm_eval/datasets/hendrycks_ethics/__init__.py
lm_eval/datasets/hendrycks_math/__init__.py
lm_eval/datasets/lambada/__init__.py
lm_eval/datasets/logiqa/__init__.py
lm_eval/datasets/mutual/__init__.py
lm_eval/datasets/pile/__init__.py
lm_eval/datasets/quac/__init__.py
lm_eval/datasets/sat_analogies/__init__.py
lm_eval/datasets/triviaqa/__init__.py
lm_eval/datasets/truthfulqa/__init__.py
lm_eval/datasets/unscramble/__init__.py