gaoqiong/lm-evaluation-harness
Commit c9c141d2 authored Jun 29, 2023 by haileyschoelkopf

add err handling for multi-tok stopseq

parent 72b7f0c0
Showing 1 changed file with 27 additions and 6 deletions
lm_eval/base.py
@@ -176,7 +176,9 @@ class BaseLM(LM):
     def _detect_batch_size(self, requests=None, pos=0):
         if requests:
             _, context_enc, continuation_enc = requests[pos]
-            max_length = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
+            max_length = len(
+                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
+            )
         else:
             max_length = self.max_length
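For context: the slicing above keeps the last self.max_length + 1 tokens of the concatenated encoding and then drops the final token, so the sequence used to probe for the largest workable batch size is at most self.max_length tokens long. A small worked example with made-up token IDs and a hypothetical max_length (not values from the commit):

    # Illustration only: hypothetical token IDs and max_length.
    context_enc = [10, 11, 12, 13]
    continuation_enc = [20, 21]
    max_length = 3  # pretend self.max_length == 3

    full = context_enc + continuation_enc    # [10, 11, 12, 13, 20, 21]
    window = full[-(max_length + 1):]        # last 4 tokens -> [12, 13, 20, 21]
    probe = window[:-1]                      # drop the final token -> [12, 13, 20]
    print(len(probe))                        # 3 == max_length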
@@ -212,7 +214,9 @@ class BaseLM(LM):
         for context, continuation in requests:
             if context == "":
                 # end of text as context
-                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(continuation)
+                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
+                    continuation
+                )
             else:
                 context_enc, continuation_enc = self._encode_pair(context, continuation)
@@ -290,15 +294,23 @@ class BaseLM(LM):
             sched = pos // int(n_reordered_requests / self.batch_schedule)
             if sched in self.batch_sizes:
                 return self.batch_sizes[sched]
-            print(f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size")
+            print(
+                f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
+            )
             self.batch_sizes[sched] = self._detect_batch_size(reordered_requests, pos)
             print(f"Determined largest batch size: {self.batch_sizes[sched]}")
             return self.batch_sizes[sched]

         for chunk in utils.chunks(
             tqdm(reordered_requests, disable=disable_tqdm),
-            n=self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0,
-            fn=_batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 else None,
+            n=self.batch_size
+            if self.batch_size != "auto"
+            else override_bs
+            if override_bs is not None
+            else 0,
+            fn=_batch_scheduler
+            if self.batch_size == "auto" and n_reordered_requests > 0
+            else None,
         ):
             inps = []
             cont_toks_list = []
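For context: _batch_scheduler splits the reordered request stream into self.batch_schedule segments and re-runs _detect_batch_size once per segment; since requests are typically reordered longest-first, later segments can often fit a larger batch. A rough illustration of the segment arithmetic with made-up numbers (not from the commit):

    # Illustration only: hypothetical request count and schedule.
    n_reordered_requests = 1000
    batch_schedule = 4                  # stand-in for self.batch_schedule

    segment_size = int(n_reordered_requests / batch_schedule)   # 250
    for pos in (0, 249, 250, 999):
        sched = pos // segment_size
        print(pos, "-> segment", sched)  # 0 and 249 -> 0, 250 -> 1, 999 -> 3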
@@ -411,13 +423,22 @@ class BaseLM(LM):
         re_ord = utils.Reorderer(requests, _collate)

+        warn_stop_seq = False
         for context, request_args in tqdm(re_ord.get_reordered()):
             until = request_args["until"]
             if isinstance(until, str):
                 until = [until]

             if until:
-                (primary_until,) = self.tok_encode(until[0])
+                try:
+                    (primary_until,) = self.tok_encode(until[0])
+                except ValueError:
+                    if not warn_stop_seq:
+                        print(
+                            "Warning: a primary stop sequence is multi-token! Will default to EOS token for this tokenizer. Consider using `hf-causal-experimental` for multi-token stop sequence support for the time being."
+                        )
+                        warn_stop_seq = True
+                    primary_until = self.eot_token_id
             else:
                 primary_until = None
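The new try/except is the error handling named in the commit message: if the first stop sequence encodes to more than one token, unpacking the result into (primary_until,) raises ValueError, and the code now falls back to the tokenizer's EOS/EOT id, warning once per run. A minimal sketch of that failure mode, using a stand-in tok_encode rather than the harness's real tokenizer:

    # Sketch only: tok_encode and eot_token_id are hypothetical stand-ins.
    def tok_encode(s):
        return [101, 102] if s == "\n\n" else [7]  # pretend "\n\n" is two tokens

    eot_token_id = 0
    until = ["\n\n"]

    try:
        (primary_until,) = tok_encode(until[0])    # ValueError: too many values to unpack
    except ValueError:
        primary_until = eot_token_id               # fall back to EOS, as the commit does
    print(primary_until)                           # 0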