Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
5e59320b
Commit
5e59320b
authored
Apr 28, 2022
by
Tian Yun
Browse files
Modified stopping criteria for T5 and GPT-2
parent
c46ff9e4
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
47 additions
and
19 deletions
+47
-19
lm_eval/base.py
lm_eval/base.py
+10
-2
lm_eval/evaluator.py
lm_eval/evaluator.py
+2
-1
lm_eval/models/gpt2.py
lm_eval/models/gpt2.py
+17
-8
lm_eval/models/t5.py
lm_eval/models/t5.py
+18
-8
No files found.
lm_eval/base.py
View file @
5e59320b
...
...
@@ -353,11 +353,13 @@ class BaseLM(LM):
for
context
,
request_args
in
tqdm
(
reord
.
get_reordered
()):
stopping_criteria
=
request_args
[
"stopping_criteria"
]
max_generation_length
=
request_args
[
"max_generation_length"
]
num_fewshot
=
request_args
[
"num_fewshot"
]
assert
isinstance
(
stopping_criteria
,
str
)
or
stopping_criteria
is
None
assert
(
isinstance
(
max_generation_length
,
int
)
or
max_generation_length
is
None
)
assert
isinstance
(
num_fewshot
,
int
)
or
num_fewshot
is
None
if
stopping_criteria
is
None
:
until
=
[
self
.
eot_token
]
...
...
@@ -382,6 +384,7 @@ class BaseLM(LM):
context_enc
,
max_length
,
torch
.
tensor
(
primary_until
),
num_fewshot
,
)
s
=
self
.
tok_decode
(
cont
.
tolist
())
...
...
@@ -536,7 +539,7 @@ class Task(abc.ABC):
pass
@
abstractmethod
def
construct_requests
(
self
,
doc
,
ctx
):
def
construct_requests
(
self
,
doc
,
ctx
,
args
):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
...
...
@@ -546,6 +549,8 @@ class Task(abc.ABC):
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
:param args: dict
The specifics of the context, including number of few shots.
"""
pass
...
...
@@ -724,7 +729,7 @@ class PromptSourceTask(Task):
text
,
_
=
self
.
prompt
.
apply
(
doc
)
return
text
def
construct_requests
(
self
,
doc
,
ctx
):
def
construct_requests
(
self
,
doc
,
ctx
,
args
):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
...
...
@@ -734,6 +739,8 @@ class PromptSourceTask(Task):
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
:param args: dict
The specifics of the context, including number of few shots.
"""
_requests
=
[]
answer_choices_list
=
self
.
prompt
.
get_answer_choices_list
(
doc
)
...
...
@@ -749,6 +756,7 @@ class PromptSourceTask(Task):
request_args
=
{
"stopping_criteria"
:
self
.
stopping_criteria
(),
"max_generation_length"
:
self
.
max_generation_length
(),
"num_fewshot"
:
args
[
"num_fewshot"
],
}
cont_request
=
rf
.
greedy_until
(
ctx
,
request_args
)
_requests
.
append
(
cont_request
)
...
...
lm_eval/evaluator.py
View file @
5e59320b
...
...
@@ -206,7 +206,8 @@ def evaluate(
doc
=
doc
,
num_fewshot
=
num_fewshot
,
rnd
=
rnd
,
description
=
description
)
fewshotex_logging_info
[
"doc_id"
]
=
original_doc_id
reqs
=
task
.
construct_requests
(
doc
,
ctx
)
args
=
{
"num_fewshot"
:
num_fewshot
}
reqs
=
task
.
construct_requests
(
doc
,
ctx
,
args
)
if
not
isinstance
(
reqs
,
(
list
,
tuple
)):
reqs
=
[
reqs
]
for
i
,
req
in
enumerate
(
reqs
):
...
...
lm_eval/models/gpt2.py
View file @
5e59320b
...
...
@@ -149,14 +149,23 @@ class HFLM(BaseLM):
EOSCriteria
(
self
.
tokenizer
.
eos_token
)
])
def
_model_generate
(
self
,
context
,
max_length
,
stopping_criteria_ids
):
# stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
generations
=
self
.
gpt2
.
generate
(
context
,
max_length
=
max_length
,
# stopping_criteria=stopping_criteria,
do_sample
=
False
,
)
def
_model_generate
(
self
,
context
,
max_length
,
stopping_criteria_ids
,
num_fewshot
):
stopping_criteria
=
self
.
_get_stopping_criteria
(
stopping_criteria_ids
)
if
num_fewshot
==
0
:
generations
=
self
.
gpt2
.
generate
(
context
,
max_length
=
max_length
,
eos_token_id
=
self
.
eot_token_id
,
do_sample
=
False
,
)
else
:
generations
=
self
.
gpt2
.
generate
(
context
,
max_length
=
max_length
,
stopping_criteria
=
stopping_criteria
,
do_sample
=
False
,
)
# Remove the context from the generations
return
generations
[
0
,
context
.
shape
[
1
]
:]
...
...
lm_eval/models/t5.py
View file @
5e59320b
...
...
@@ -186,11 +186,21 @@ class T5LM(BaseLM):
EOSCriteria
(
self
.
tokenizer
.
eos_token
)
])
def
_model_generate
(
self
,
context
,
max_length
,
stopping_criteria_ids
):
# stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
return
self
.
t5
.
generate
(
context
,
max_length
=
max_length
,
# stopping_criteria=stopping_criteria,
do_sample
=
False
,
)[
0
]
def
_model_generate
(
self
,
context
,
max_length
,
stopping_criteria_ids
,
num_fewshot
):
stopping_criteria
=
self
.
_get_stopping_criteria
(
stopping_criteria_ids
)
if
num_fewshot
==
0
:
generations
=
self
.
t5
.
generate
(
context
,
max_length
=
max_length
,
eos_token_id
=
self
.
eot_token_id
,
do_sample
=
False
,
)
else
:
generations
=
self
.
t5
.
generate
(
context
,
max_length
=
max_length
,
stopping_criteria
=
stopping_criteria
,
do_sample
=
False
,
)
return
generations
[
0
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment