gaoqiong / lm-evaluation-harness / Commits / 5e59320b
"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "32c30c3bf9c10189a0e3bae6327a5085a05a5143"
Commit 5e59320b authored Apr 28, 2022 by Tian Yun

Modified stopping criteria for T5 and GPT-2

parent c46ff9e4
Showing 4 changed files with 47 additions and 19 deletions (+47 −19)
lm_eval/base.py         +10 −2
lm_eval/evaluator.py     +2 −1
lm_eval/models/gpt2.py  +17 −8
lm_eval/models/t5.py    +18 −8
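Taken together, the changes below thread the few-shot count from the evaluator down to the model: evaluate() now passes args = {"num_fewshot": num_fewshot} into construct_requests, the task copies it into request_args, and _model_generate uses it to choose between plain EOS-token stopping and a custom StoppingCriteria list. A minimal, self-contained sketch of that call chain (ToyTask and ToyLM are illustrative stand-ins, not the harness's real classes):

# Illustrative sketch only: toy stand-ins for the harness's Task and LM classes.

class ToyTask:
    def stopping_criteria(self):
        return "\n\n"  # example stop string

    def max_generation_length(self):
        return 64

    def construct_requests(self, doc, ctx, args):
        # evaluator.py now supplies args = {"num_fewshot": num_fewshot}
        request_args = {
            "stopping_criteria": self.stopping_criteria(),
            "max_generation_length": self.max_generation_length(),
            "num_fewshot": args["num_fewshot"],
        }
        return [(ctx, request_args)]

class ToyLM:
    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
        # Zero-shot: stop on the model's EOS token.
        # Few-shot: stop via a custom StoppingCriteria list instead.
        mode = "eos_token_id" if num_fewshot == 0 else "stopping_criteria"
        return f"generate({context!r}, max_length={max_length}, {mode}=...)"

task, lm = ToyTask(), ToyLM()
ctx, request_args = task.construct_requests({}, "Q: 2 + 2 = ? A:", {"num_fewshot": 0})[0]
print(lm._model_generate(ctx, request_args["max_generation_length"],
                         stopping_criteria_ids=None,
                         num_fewshot=request_args["num_fewshot"]))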
lm_eval/base.py
@@ -353,11 +353,13 @@ class BaseLM(LM):
         for context, request_args in tqdm(reord.get_reordered()):
             stopping_criteria = request_args["stopping_criteria"]
             max_generation_length = request_args["max_generation_length"]
+            num_fewshot = request_args["num_fewshot"]
             assert isinstance(stopping_criteria, str) or stopping_criteria is None
             assert (
                 isinstance(max_generation_length, int) or max_generation_length is None
             )
+            assert isinstance(num_fewshot, int) or num_fewshot is None
             if stopping_criteria is None:
                 until = [self.eot_token]
@@ -382,6 +384,7 @@ class BaseLM(LM):
                 context_enc,
                 max_length,
                 torch.tensor(primary_until),
+                num_fewshot,
             )
             s = self.tok_decode(cont.tolist())
@@ -536,7 +539,7 @@ class Task(abc.ABC):
         pass

     @abstractmethod
-    def construct_requests(self, doc, ctx):
+    def construct_requests(self, doc, ctx, args):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
@@ -546,6 +549,8 @@ class Task(abc.ABC):
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
+        :param args: dict
+            The specifics of the context, including number of few shots.
         """
         pass
@@ -724,7 +729,7 @@ class PromptSourceTask(Task):
         text, _ = self.prompt.apply(doc)
         return text

-    def construct_requests(self, doc, ctx):
+    def construct_requests(self, doc, ctx, args):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
@@ -734,6 +739,8 @@ class PromptSourceTask(Task):
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
+        :param args: dict
+            The specifics of the context, including number of few shots.
         """
         _requests = []
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
@@ -749,6 +756,7 @@ class PromptSourceTask(Task):
             request_args = {
                 "stopping_criteria": self.stopping_criteria(),
                 "max_generation_length": self.max_generation_length(),
+                "num_fewshot": args["num_fewshot"],
             }
             cont_request = rf.greedy_until(ctx, request_args)
             _requests.append(cont_request)
lm_eval/evaluator.py
@@ -206,7 +206,8 @@ def evaluate(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
             fewshotex_logging_info["doc_id"] = original_doc_id
-            reqs = task.construct_requests(doc, ctx)
+            args = {"num_fewshot": num_fewshot}
+            reqs = task.construct_requests(doc, ctx, args)
             if not isinstance(reqs, (list, tuple)):
                 reqs = [reqs]
             for i, req in enumerate(reqs):
lm_eval/models/gpt2.py
@@ -149,14 +149,23 @@ class HFLM(BaseLM):
                 EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
-        # stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        generations = self.gpt2.generate(
-            context,
-            max_length=max_length,
-            # stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
+        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        if num_fewshot == 0:
+            generations = self.gpt2.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.gpt2.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
         # Remove the context from the generations
         return generations[0, context.shape[1]:]
lm_eval/models/t5.py
@@ -186,11 +186,21 @@ class T5LM(BaseLM):
                 EOSCriteria(self.tokenizer.eos_token)
         ])

-    def _model_generate(self, context, max_length, stopping_criteria_ids):
-        # stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-        return self.t5.generate(
-            context,
-            max_length=max_length,
-            # stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )[0]
+    def _model_generate(self, context, max_length, stopping_criteria_ids, num_fewshot):
+        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        if num_fewshot == 0:
+            generations = self.t5.generate(
+                context,
+                max_length=max_length,
+                eos_token_id=self.eot_token_id,
+                do_sample=False,
+            )
+        else:
+            generations = self.t5.generate(
+                context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                do_sample=False,
+            )
+        return generations[0]
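For reference, the num_fewshot != 0 branches above pass a stopping_criteria argument to Hugging Face's generate(). The EOSCriteria class returned by _get_stopping_criteria is not part of this diff, so the criterion below is only an illustrative assumption of how such a token-sequence stop can be written against the transformers StoppingCriteria API:

# Hedged sketch: a token-sequence stopping criterion in the style the diff relies on.
# StopOnTokenSequence is hypothetical; only StoppingCriteria / StoppingCriteriaList
# and model.generate(..., stopping_criteria=...) are real transformers APIs.
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokenSequence(StoppingCriteria):
    """Stop generation once the newest tokens equal a given id sequence."""

    def __init__(self, stop_ids):
        self.stop_ids = torch.tensor(stop_ids, dtype=torch.long)

    def __call__(self, input_ids, scores, **kwargs):
        n = self.stop_ids.shape[0]
        if input_ids.shape[1] < n:
            return False
        return bool(torch.equal(input_ids[0, -n:].cpu(), self.stop_ids))

# Hypothetical usage, assuming `tokenizer` and `model` are a GPT-2 style pair:
# stop_ids = tokenizer.encode("\n\n", add_special_tokens=False)
# criteria = StoppingCriteriaList([StopOnTokenSequence(stop_ids)])
# model.generate(input_ids, max_length=64, stopping_criteria=criteria, do_sample=False)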