gaoqiong / lm-evaluation-harness

Commit 21d897db, authored Apr 26, 2022 by cjlovering
Parent: 4f85bcf9

Updated the requests so that it's easier to understand.

Showing 1 changed file with 19 additions and 15 deletions.
lm_eval/base.py
@@ -345,25 +345,27 @@ class BaseLM(LM):
         reord = utils.Reorderer(requests, _collate)

-        for context, until in tqdm(reord.get_reordered()):
-            if isinstance(until, str):
-                until = [until]
-                max_length = None
-            elif isinstance(until, list) and len(until) == 2:
-                until, max_length = [until[0]], until[1]
-            elif isinstance(until, list):
-                max_length = None
+        for context, request_args in tqdm(reord.get_reordered()):
+            stopping_criteria = request_args["stopping_criteria"]
+            max_generation_length = request_args["max_generation_length"]
+            assert isinstance(stopping_criteria, str) or stopping_criteria is None
+            assert (
+                isinstance(max_generation_length, int) or max_generation_length is None
+            )
+            until = [stopping_criteria]

             primary_until = self.tok_encode(until[0])

             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)

-            if max_length is not None:
-                max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks)
-            else:
+            if max_generation_length is None:
                 max_length = context_enc.shape[1] + self.max_gen_toks
+            else:
+                max_length = min(
+                    max_generation_length, context_enc.shape[1] + self.max_gen_toks
+                )

             cont = self._model_generate(
                 context_enc,
                 max_length,
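On the model side, greedy_until now receives a named request_args dict and validates each field, instead of guessing the caller's intent from whether until is a string or a two-element list. The sketch below mirrors the generation-cutoff branch added above; it is illustrative only, and the helper name resolve_max_length and the stand-in constant MAX_GEN_TOKS (playing the role of self.max_gen_toks) are not part of the repository.

    # Illustrative sketch only -- mirrors the branch added in the hunk above.
    MAX_GEN_TOKS = 256  # assumed value, stands in for self.max_gen_toks


    def resolve_max_length(request_args, context_len):
        """Pick the generation cutoff the way the rewritten loop does."""
        stopping_criteria = request_args["stopping_criteria"]
        max_generation_length = request_args["max_generation_length"]
        # Same sanity checks the commit introduces.
        assert isinstance(stopping_criteria, str) or stopping_criteria is None
        assert isinstance(max_generation_length, int) or max_generation_length is None
        if max_generation_length is None:
            return context_len + MAX_GEN_TOKS
        return min(max_generation_length, context_len + MAX_GEN_TOKS)


    # Explicit cap: min(64, 100 + 256) == 64
    print(resolve_max_length({"stopping_criteria": "\n", "max_generation_length": 64}, 100))
    # No cap: falls back to 100 + 256 == 356
    print(resolve_max_length({"stopping_criteria": "\n", "max_generation_length": None}, 100))

With an explicit cap, the cutoff is the smaller of the cap and the context length plus the default budget; with None, it falls back to the context length plus the default budget.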
@@ -720,9 +722,11 @@ class PromptSourceTask(Task):
         else:
             # If not, then this is a generation prompt.
             # NOTE: In the future, target will be a list of strings.
-            cont_request = rf.greedy_until(
-                ctx, [self.stopping_criteria(), self.max_generation_length()]
-            )
+            request_args = {
+                "stopping_criteria": self.stopping_criteria(),
+                "max_generation_length": self.max_generation_length(),
+            }
+            cont_request = rf.greedy_until(ctx, request_args)
             _requests.append(cont_request)
         return _requests
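On the task side, generation requests now hand rf.greedy_until a dict with named keys rather than a positional [stopping_criteria, max_generation_length] list, which is what makes the asserts in the first hunk possible. A minimal sketch of that producer-side contract, assuming a hypothetical ToyTask with hard-coded return values:

    # Illustrative sketch only -- ToyTask and its values are hypothetical.
    class ToyTask:
        """Stand-in for a PromptSourceTask subclass after this commit."""

        def stopping_criteria(self):
            # Stop generating at the first newline.
            return "\n"

        def max_generation_length(self):
            # An integer cap, or None to fall back to the model's default budget.
            return 128

        def construct_request_args(self):
            # Mirrors the dict built in construct_requests above.
            return {
                "stopping_criteria": self.stopping_criteria(),
                "max_generation_length": self.max_generation_length(),
            }


    print(ToyTask().construct_request_args())
    # {'stopping_criteria': '\n', 'max_generation_length': 128}

Naming the fields makes the call site self-documenting, which matches the stated aim of the commit.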