Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
1444a36c
Unverified
Commit
1444a36c
authored
Apr 28, 2022
by
KhalidAlt
Committed by
GitHub
Apr 28, 2022
Browse files
Merge branch 'master' into master
parents
13676905
22155f7d
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
138 additions
and
37 deletions
+138
-37
lm_eval/base.py
lm_eval/base.py
+2
-4
lm_eval/models/gptj.py
lm_eval/models/gptj.py
+25
-10
lm_eval/models/t0.py
lm_eval/models/t0.py
+26
-8
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+5
-0
lm_eval/tasks/coqa.py
lm_eval/tasks/coqa.py
+2
-2
lm_eval/tasks/crows_pairs_multilingual.py
lm_eval/tasks/crows_pairs_multilingual.py
+65
-0
lm_eval/tasks/drop.py
lm_eval/tasks/drop.py
+2
-2
lm_eval/tasks/gem_asset_turk.py
lm_eval/tasks/gem_asset_turk.py
+2
-2
lm_eval/tasks/gem_webnlg.py
lm_eval/tasks/gem_webnlg.py
+2
-2
lm_eval/tasks/glue.py
lm_eval/tasks/glue.py
+2
-2
lm_eval/tasks/wino_bias.py
lm_eval/tasks/wino_bias.py
+2
-2
templates/new_task.py
templates/new_task.py
+3
-3
No files found.
lm_eval/base.py
View file @
1444a36c
...
@@ -694,11 +694,9 @@ class PromptSourceTask(Task):
...
@@ -694,11 +694,9 @@ class PromptSourceTask(Task):
def
stopping_criteria
(
self
)
->
Optional
[
str
]:
def
stopping_criteria
(
self
)
->
Optional
[
str
]:
"""Denote where the generation should end.
"""Denote where the generation should end.
For example, for coqa, this is '
\n
Q:' and for drop '.'.
By default, it's "
\n
###
\n
".
By default, it's None, meaning generation continues up to the max length or EOT, whichever comes first.
"""
"""
return
None
return
"
\n
###
\n
"
def
max_generation_length
(
self
)
->
Optional
[
int
]:
def
max_generation_length
(
self
)
->
Optional
[
int
]:
"""Denote where the max length of the generation if it is obvious from the task."""
"""Denote where the max length of the generation if it is obvious from the task."""
...
...
lm_eval/models/gptj.py
View file @
1444a36c
...
@@ -8,6 +8,7 @@ class GPTJLM(BaseLM):
...
@@ -8,6 +8,7 @@ class GPTJLM(BaseLM):
self
,
self
,
device
=
"cuda"
,
device
=
"cuda"
,
batch_size
=
1
,
batch_size
=
1
,
parallelize
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -35,9 +36,11 @@ class GPTJLM(BaseLM):
...
@@ -35,9 +36,11 @@ class GPTJLM(BaseLM):
self
.
batch_size_per_gpu
=
batch_size
# todo: adaptive batch size
self
.
batch_size_per_gpu
=
batch_size
# todo: adaptive batch size
# TODO: fix multi-gpu
# TODO: fix multi-gpu
# gpus = torch.cuda.device_count()
if
parallelize
:
# if gpus > 1:
self
.
gptj
.
parallelize
()
# self.gptj = nn.DataParallel(self.gptj)
self
.
_device
=
torch
.
device
(
'cuda:0'
)
else
:
self
.
gptj
.
to
(
self
.
_device
)
@
property
@
property
def
eot_token
(
self
):
def
eot_token
(
self
):
...
@@ -113,11 +116,23 @@ class GPTJLM(BaseLM):
...
@@ -113,11 +116,23 @@ class GPTJLM(BaseLM):
EOSCriteria
(
self
.
tokenizer
.
eos_token
)
EOSCriteria
(
self
.
tokenizer
.
eos_token
)
])
])
def
_model_generate
(
self
,
context
,
max_length
,
stopping_criteria_ids
):
def
_model_generate
(
self
,
context
,
max_length
,
stopping_criteria_ids
,
num_fewshot
):
stopping_criteria
=
self
.
_get_stopping_criteria
(
stopping_criteria_ids
)
stopping_criteria
=
self
.
_get_stopping_criteria
(
stopping_criteria_ids
)
return
self
.
gptj
.
generate
(
if
num_fewshot
==
0
:
generations
=
self
.
gptj
.
generate
(
context
,
max_length
=
max_length
,
eos_token_id
=
self
.
eot_token_id
,
do_sample
=
False
,
)
else
:
generations
=
self
.
gptj
.
generate
(
context
,
context
,
max_length
=
max_length
,
max_length
=
max_length
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
do_sample
=
False
,
do_sample
=
False
,
)
)
# Remove the context from the generations
return
generations
[
0
,
context
.
shape
[
1
]
:]
lm_eval/models/t0.py
View file @
1444a36c
...
@@ -56,7 +56,7 @@ class T0LM(BaseLM):
...
@@ -56,7 +56,7 @@ class T0LM(BaseLM):
@
property
@
property
def
max_gen_toks
(
self
):
def
max_gen_toks
(
self
):
return
self
.
tokenizer
.
model_max_length
return
256
@
property
@
property
def
batch_size
(
self
):
def
batch_size
(
self
):
...
@@ -94,6 +94,14 @@ class T0LM(BaseLM):
...
@@ -94,6 +94,14 @@ class T0LM(BaseLM):
inputs
,
targets
=
zip
(
*
chunk
)
inputs
,
targets
=
zip
(
*
chunk
)
# Fill in empty encoder inputs with eos_token
inputs
=
(
f
"
{
self
.
eot_token
}
"
if
len
(
input_
)
==
0
else
input_
for
input_
in
inputs
)
inputs_tok
=
self
.
tokenizer
(
inputs_tok
=
self
.
tokenizer
(
list
(
inputs
),
list
(
inputs
),
max_length
=
self
.
max_length
,
max_length
=
self
.
max_length
,
...
@@ -172,11 +180,21 @@ class T0LM(BaseLM):
...
@@ -172,11 +180,21 @@ class T0LM(BaseLM):
EOSCriteria
(
self
.
tokenizer
.
eos_token
)
EOSCriteria
(
self
.
tokenizer
.
eos_token
)
])
])
def
_model_generate
(
self
,
context
,
max_length
,
stopping_criteria_ids
):
def
_model_generate
(
self
,
context
,
max_length
,
stopping_criteria_ids
,
num_fewshot
):
stopping_criteria
=
self
.
_get_stopping_criteria
(
stopping_criteria_ids
)
stopping_criteria
=
self
.
_get_stopping_criteria
(
stopping_criteria_ids
)
return
self
.
t0
.
generate
(
if
num_fewshot
==
0
:
generations
=
self
.
t0
.
generate
(
context
,
max_length
=
max_length
,
eos_token_id
=
self
.
eot_token_id
,
do_sample
=
False
,
)
else
:
generations
=
self
.
t0
.
generate
(
context
,
context
,
max_length
=
max_length
,
max_length
=
max_length
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
do_sample
=
False
,
do_sample
=
False
,
)
)
return
generations
[
0
]
lm_eval/tasks/__init__.py
View file @
1444a36c
...
@@ -62,6 +62,7 @@ from . import gem_mlsum
...
@@ -62,6 +62,7 @@ from . import gem_mlsum
from
.
import
wino_bias
from
.
import
wino_bias
from
.
import
e2e_nlg_cleaned
from
.
import
e2e_nlg_cleaned
from
.
import
gem_asset_turk
from
.
import
gem_asset_turk
from
.
import
crows_pairs_multilingual
from
.
import
lama
from
.
import
lama
########################################
########################################
...
@@ -333,6 +334,10 @@ TASK_REGISTRY = {
...
@@ -333,6 +334,10 @@ TASK_REGISTRY = {
"wino_bias_type1_anti"
:
wino_bias
.
WinoBiasType1Anti
,
"wino_bias_type1_anti"
:
wino_bias
.
WinoBiasType1Anti
,
"wino_bias_type2_pro"
:
wino_bias
.
WinoBiasType2Pro
,
"wino_bias_type2_pro"
:
wino_bias
.
WinoBiasType2Pro
,
"wino_bias_type2_anti"
:
wino_bias
.
WinoBiasType2Anti
,
"wino_bias_type2_anti"
:
wino_bias
.
WinoBiasType2Anti
,
# Crows-Pairs
"crows_pairs_english"
:
crows_pairs_multilingual
.
CrowsPairsEnglish
,
"crows_pairs_french"
:
crows_pairs_multilingual
.
CrowsPairsFrench
,
}
}
...
...
lm_eval/tasks/coqa.py
View file @
1444a36c
...
@@ -90,8 +90,8 @@ class CoQA(PromptSourceTask):
...
@@ -90,8 +90,8 @@ class CoQA(PromptSourceTask):
"f1"
:
f1_sum
/
max
(
1
,
len
(
gold_list
)),
"f1"
:
f1_sum
/
max
(
1
,
len
(
gold_list
)),
}
}
def
stopping_criteria
(
self
):
#
def stopping_criteria(self):
return
"
\n\n
"
#
return "\n\n"
# def construct_requests(self, doc, ctx):
# def construct_requests(self, doc, ctx):
# """Uses RequestFactory to construct Requests and returns an iterable of
# """Uses RequestFactory to construct Requests and returns an iterable of
...
...
lm_eval/tasks/crows_pairs_multilingual.py
0 → 100644
View file @
1444a36c
"""
French CrowS-Pairs: Extending a challenge dataset for measuring social bias in masked language models to a language other than English
https://hal.inria.fr/hal-03629677/file/ACLFinal.pdf
Measuring social biases in masked language models in English and French.
https://gitlab.inria.fr/french-crows-pairs/acl-2022-paper-data-and-code/-/tree/main
"""
from
lm_eval.base
import
PromptSourceTask
_CITATION
=
"""
\
@inproceedings{neveol2022french,
title={French CrowS-Pairs: Extending a challenge dataset for measuring social bias in masked language models to a language other than English},
author={N{
\'
e}v{
\'
e}ol, Aur{
\'
e}lie and Dupont, Yoann and Bezan{\c{c}}on, Julien and Fort, Kar{
\"
e}n},
booktitle={ACL 2022-60th Annual Meeting of the Association for Computational Linguistics},
year={2022}
"""
class CrowsPairsEnglish(PromptSourceTask):
    """English configuration of the multilingual CrowS-Pairs bias dataset.

    Only a test split is advertised; the training and validation accessors
    exist to satisfy the task interface and yield nothing.
    """

    VERSION = 0
    DATASET_PATH = "oskarvanderwal/crows_pairs_multilingual"
    DATASET_NAME = "english"

    def has_training_docs(self):
        """Report that no training split is available."""
        return False

    def has_validation_docs(self):
        """Report that no validation split is available."""
        return False

    def has_test_docs(self):
        """Report that a test split is available."""
        return True

    def training_docs(self):
        """Return None: this task exposes no training documents."""
        return None

    def validation_docs(self):
        """Return None: this task exposes no validation documents."""
        return None

    def test_docs(self):
        """Return the "test" split of ``self.dataset``, or None if there is none.

        NOTE(review): ``self.dataset`` is presumably populated by the base
        class from DATASET_PATH / DATASET_NAME -- confirm in lm_eval.base.
        """
        if not self.has_test_docs():
            return None
        return self.dataset["test"]
class CrowsPairsFrench(PromptSourceTask):
    """French configuration of the multilingual CrowS-Pairs bias dataset.

    Only a test split is advertised; the training and validation accessors
    exist to satisfy the task interface and yield nothing.
    """

    VERSION = 0
    DATASET_PATH = "oskarvanderwal/crows_pairs_multilingual"
    DATASET_NAME = "french"

    def has_training_docs(self):
        """Report that no training split is available."""
        return False

    def has_validation_docs(self):
        """Report that no validation split is available."""
        return False

    def has_test_docs(self):
        """Report that a test split is available."""
        return True

    def training_docs(self):
        """Return None: this task exposes no training documents."""
        return None

    def validation_docs(self):
        """Return None: this task exposes no validation documents."""
        return None

    def test_docs(self):
        """Return the "test" split of ``self.dataset``, or None if there is none.

        NOTE(review): ``self.dataset`` is presumably populated by the base
        class from DATASET_PATH / DATASET_NAME -- confirm in lm_eval.base.
        """
        if not self.has_test_docs():
            return None
        return self.dataset["test"]
lm_eval/tasks/drop.py
View file @
1444a36c
...
@@ -92,8 +92,8 @@ class DROP(PromptSourceTask):
...
@@ -92,8 +92,8 @@ class DROP(PromptSourceTask):
# """
# """
# conts = [rf.greedy_until(ctx, ["."])]
# conts = [rf.greedy_until(ctx, ["."])]
# return conts
# return conts
def
stopping_criteria
(
self
):
#
def stopping_criteria(self):
return
"."
#
return "."
def
process_results
(
self
,
doc
,
results
):
def
process_results
(
self
,
doc
,
results
):
"""Take a single document and the LM results and evaluates, returning a
"""Take a single document and the LM results and evaluates, returning a
...
...
lm_eval/tasks/gem_asset_turk.py
View file @
1444a36c
...
@@ -78,8 +78,8 @@ class AssetTurk(PromptSourceTask):
...
@@ -78,8 +78,8 @@ class AssetTurk(PromptSourceTask):
def
test_docs
(
self
):
def
test_docs
(
self
):
return
self
.
dataset
[
str
(
self
.
SPLIT
)]
return
self
.
dataset
[
str
(
self
.
SPLIT
)]
def
stopping_criteria
(
self
):
#
def stopping_criteria(self):
return
None
#
return None
def
max_generation_length
(
self
):
def
max_generation_length
(
self
):
return
200
return
200
...
...
lm_eval/tasks/gem_webnlg.py
View file @
1444a36c
...
@@ -70,8 +70,8 @@ class WebNLG(PromptSourceTask):
...
@@ -70,8 +70,8 @@ class WebNLG(PromptSourceTask):
else
:
else
:
return
self
.
dataset
[
"test"
]
return
self
.
dataset
[
"test"
]
def
stopping_criteria
(
self
):
#
def stopping_criteria(self):
return
None
#
return None
def
max_generation_length
(
self
):
def
max_generation_length
(
self
):
return
250
return
250
...
...
lm_eval/tasks/glue.py
View file @
1444a36c
...
@@ -236,8 +236,8 @@ class MRPC(PromptSourceTask):
...
@@ -236,8 +236,8 @@ class MRPC(PromptSourceTask):
def
has_test_docs
(
self
):
def
has_test_docs
(
self
):
return
False
return
False
def
stopping_criteria
(
self
):
#
def stopping_criteria(self):
return
"
\n
"
#
return "\n
###\n
"
def
training_docs
(
self
):
def
training_docs
(
self
):
if
self
.
_training_docs
is
None
:
if
self
.
_training_docs
is
None
:
...
...
lm_eval/tasks/wino_bias.py
View file @
1444a36c
...
@@ -54,8 +54,8 @@ class WinoBias(PromptSourceTask):
...
@@ -54,8 +54,8 @@ class WinoBias(PromptSourceTask):
def
test_docs
(
self
):
def
test_docs
(
self
):
return
self
.
dataset
[
"test"
]
return
self
.
dataset
[
"test"
]
def
stopping_criteria
(
self
):
#
def stopping_criteria(self):
return
"
\n
"
#
return "\n"
def
process_results
(
self
,
doc
,
results
):
def
process_results
(
self
,
doc
,
results
):
"""Take a single document and the LM results and evaluates, returning a
"""Take a single document and the LM results and evaluates, returning a
...
...
templates/new_task.py
View file @
1444a36c
...
@@ -73,10 +73,10 @@ class NewTask(PromptSourceTask):
...
@@ -73,10 +73,10 @@ class NewTask(PromptSourceTask):
return
self
.
dataset
[
"test"
]
return
self
.
dataset
[
"test"
]
def
stopping_criteria
(
self
):
def
stopping_criteria
(
self
):
#
TODO: Denote the string where the generation should be split
.
#
Only define this method when you want to control few-shot generations on specific tokens
.
#
For example, for `coqa`, this is '\nQ:' and for `drop` '.
'.
#
The default is set to '\n###\n
'.
# NOTE: You may delete this function if the task does not require generation.
# NOTE: You may delete this function if the task does not require generation.
return
None
return
"
\n
###
\n
"
def
construct_requests
(
self
,
doc
,
ctx
):
def
construct_requests
(
self
,
doc
,
ctx
):
"""Uses RequestFactory to construct Requests and returns an iterable of
"""Uses RequestFactory to construct Requests and returns an iterable of
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment