gaoqiong / lm-evaluation-harness / Commits

Commit 4941a8bb (Unverified)
Authored Apr 26, 2022 by Jonathan Tow; committed by GitHub, Apr 26, 2022

Merge pull request #1 from cjlovering/cjlovering/gen_max_len

Add optional max length to generation.

Parents: 02ec7889, 9384ec91
Showing 1 changed file with 20 additions and 6 deletions: lm_eval/base.py (+20, -6)
lm_eval/base.py (view file @ 4941a8bb)

 import abc
-from typing import Iterable
+from typing import Iterable, Optional
 import promptsource
 import numpy as np
...
@@ -348,17 +348,25 @@ class BaseLM(LM):
         for context, until in tqdm(reord.get_reordered()):
             if isinstance(until, str):
                 until = [until]
+                max_length = None
+            elif isinstance(until, list) and len(until) == 2:
+                until, max_length = [until[0]], until[1]
+            elif isinstance(until, list):
+                max_length = None

+            # TODO: Come back to for generation `eos`.
             primary_until = self.tok_encode(until[0])

             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)

+            if max_length is not None:
+                max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks)
+            else:
+                max_length = context_enc.shape[1] + self.max_gen_toks
+
             cont = self._model_generate(
                 context_enc,
-                context_enc.shape[1] + self.max_gen_toks,
+                max_length,
                 torch.tensor(primary_until),
             )
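After this hunk, the `until` argument to `BaseLM.greedy_until` can arrive in three shapes: a bare stop string, a plain list of stop strings, or a two-element `[stop_string, max_length]` pair (which is what the task-side change at the bottom of this diff produces). A minimal sketch of the new dispatch and the length cap, pulled out as free functions; the helper names `parse_until` and `effective_max_length` are illustrative, not part of the commit:

    from typing import List, Optional, Tuple, Union

    def parse_until(until: Union[str, list]) -> Tuple[List[str], Optional[int]]:
        # Mirrors the branching added to BaseLM.greedy_until.
        if isinstance(until, str):
            return [until], None            # bare stop string
        if isinstance(until, list) and len(until) == 2:
            return [until[0]], until[1]     # [stop_string, max_length] pair
        if isinstance(until, list):
            return until, None              # plain list of stop strings
        raise TypeError(f"unsupported `until`: {until!r}")

    def effective_max_length(context_len: int, max_gen_toks: int,
                             max_length: Optional[int]) -> int:
        # Mirrors the new cap: a task-supplied max_length can only shrink
        # the default budget of context length + max_gen_toks, never grow it.
        budget = context_len + max_gen_toks
        return budget if max_length is None else min(max_length, budget)

    assert parse_until("\nQ:") == (["\nQ:"], None)
    assert parse_until(["\nQ:", 64]) == (["\nQ:"], 64)
    assert effective_max_length(100, 256, max_length=64) == 64
    assert effective_max_length(100, 256, max_length=None) == 356

Note that a list of exactly two stop strings would also match the `len(until) == 2` branch and be read as a `[stop, max_length]` pair, so callers with precisely two stop strings need to be aware of that overload.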
...
@@ -652,7 +660,7 @@ class PromptSourceTask(Task):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt

-    def stopping_criteria(self):
+    def stopping_criteria(self) -> Optional[str]:
         """Denote where the generation should end.

         For example, for coqa, this is '\nQ:' and for drop '.'.
...
@@ -661,6 +669,10 @@ class PromptSourceTask(Task):
"""
"""
return
None
return
None
def
max_generation_length
(
self
)
->
Optional
[
int
]:
"""Denote where the max length of the generation if it is obvious from the task."""
return
None
def
is_generation_task
(
self
):
def
is_generation_task
(
self
):
return
(
return
(
"BLEU"
in
self
.
prompt
.
metadata
.
metrics
"BLEU"
in
self
.
prompt
.
metadata
.
metrics
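The new `max_generation_length` hook defaults to `None` and is meant to be overridden per task, alongside `stopping_criteria`. A hypothetical subclass for illustration; the class name and both return values are assumptions, and it presumes `PromptSourceTask` is importable from `lm_eval.base` as in this file:

    from typing import Optional

    from lm_eval.base import PromptSourceTask

    class DropLikeTask(PromptSourceTask):
        """Illustrative only: stop at a period and cap the generation budget."""

        def stopping_criteria(self) -> Optional[str]:
            # Per the stopping_criteria docstring above, drop stops at '.'.
            return "."

        def max_generation_length(self) -> Optional[int]:
            # Assumed cap; returning None falls back to the model default
            # of context length + max_gen_toks.
            return 32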
...
@@ -718,7 +730,9 @@ class PromptSourceTask(Task):
                 _requests.append(ll_answer_choice)
             else:
                 # TODO(Albert): What is the stop symbol? Is it model specific?
-                cont_request = rf.greedy_until(ctx, [self.stopping_criteria()])
+                cont_request = rf.greedy_until(
+                    ctx, [self.stopping_criteria(), self.max_generation_length()]
+                )
                 _requests.append(cont_request)

         return _requests
...
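Taken together, the task side now ships `[stopping_criteria(), max_generation_length()]` as the `until` payload of the greedy_until request, and that two-element list is exactly what the `len(until) == 2` branch in `BaseLM.greedy_until` unpacks back into a stop list and a length cap. A quick round trip, assuming the `parse_until` sketch from earlier is in scope:

    # Task-side payload as construct_requests now builds it (values taken
    # from the hypothetical DropLikeTask above; request plumbing elided).
    payload = [".", 32]    # == [stopping_criteria(), max_generation_length()]

    # Model side: greedy_until splits the pair back apart.
    stops, max_len = parse_until(payload)
    assert stops == ["."] and max_len == 32

One edge worth flagging: a task that overrides neither hook would ship `[None, None]`, which this branch reads as a `None` stop string, so generation tasks are presumably expected to override `stopping_criteria`.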