gaoqiong / lm-evaluation-harness

Commit d424f26b authored May 10, 2023 by Jeffrey Quesnelle
fix adaptive batch crash when there are no new requests (e.g. when pulling from cache)
parent 8fc04fe5
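The crash being fixed comes from unconditionally indexing the first reordered request before probing for a batch size: when every request has already been answered from the cache, `re_ord.get_reordered()` returns an empty list and the `[0]` lookup raises an IndexError. A minimal, self-contained sketch of the failure mode and of the guard this commit adds (the helper names and toy inputs below are illustrative, not names from the harness):

# Illustrative only; names are hypothetical stand-ins, not the harness's API.
def probe_batch_size(context_enc, continuation_enc):
    return 512  # stand-in for the @find_executable_batch_size OOM probe

def choose_adaptive_batch_size(reordered, override_bs=None):
    if len(reordered) > 0:
        # safe to peek at the longest request: the list is non-empty
        _, context_enc, continuation_enc = reordered[0]
        return probe_batch_size(context_enc, continuation_enc)
    # no new requests (e.g. everything came from the cache): nothing to probe
    return 0 if override_bs is None else override_bs

print(choose_adaptive_batch_size([]))                            # 0, instead of an IndexError
print(choose_adaptive_batch_size([("req", [1, 2, 3], [4, 5])]))  # 512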
Showing 1 changed file with 21 additions and 18 deletions.

lm_eval/base.py  +21 -18
lm_eval/base.py
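One detail in the hunk below worth spelling out: `max_context` is the input width the OOM probe has to reproduce. The slice keeps at most the last `self.max_length + 1` tokens of context plus continuation, then drops the final token, which the model only has to predict and never receives as input. A toy worked example with made-up numbers:

# Toy numbers, purely illustrative of the slice used in the hunk below.
max_length = 4
context_enc, continuation_enc = [1, 2, 3, 4], [5, 6]
inputs = (context_enc + continuation_enc)[-(max_length + 1):][:-1]
print(inputs, len(inputs))  # [2, 3, 4, 5] 4  -> width of the dummy test_batch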
@@ -254,25 +254,28 @@ class BaseLM(LM):
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
-        _, context_enc, continuation_enc = re_ord.get_reordered()[0]
-        max_context = len((context_enc + continuation_enc)[-(self.max_length + 1):][:-1])
-        if (self.batch_size == 'auto'):
-            if override_bs is None:
-                print('Passed argument batch_size = auto. Detecting largest batch size')
-                @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
-                def forward_batch(batch_size):
-                    test_batch = torch.ones((batch_size, max_context), device=self.device).long()
-                    for _ in range(5):
-                        out = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
-                    return batch_size
-
-                batch_size = forward_batch()
-                print(f"Determined largest batch size: {batch_size}")
-                adaptive_batch_size = batch_size
-            else:
-                adaptive_batch_size = override_bs
+        if len(re_ord.get_reordered()) > 0:
+            _, context_enc, continuation_enc = re_ord.get_reordered()[0]
+            max_context = len((context_enc + continuation_enc)[-(self.max_length + 1):][:-1])
+            if (self.batch_size == 'auto'):
+                if override_bs is None:
+                    print('Passed argument batch_size = auto. Detecting largest batch size')
+                    @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
+                    def forward_batch(batch_size):
+                        test_batch = torch.ones((batch_size, max_context), device=self.device).long()
+                        for _ in range(5):
+                            out = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
+                        return batch_size
+
+                    batch_size = forward_batch()
+                    print(f"Determined largest batch size: {batch_size}")
+                    adaptive_batch_size = batch_size
+                else:
+                    adaptive_batch_size = override_bs
+        else:
+            adaptive_batch_size = 0 if override_bs is None else override_bs
 
         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=disable_tqdm),
             self.batch_size if self.batch_size != "auto" else adaptive_batch_size
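For context on the probe itself, which this commit only re-indents: `find_executable_batch_size` appears to be Hugging Face Accelerate's decorator of the same name, which calls the wrapped function with `starting_batch_size` and, each time the call hits an out-of-memory error, halves the batch size and retries, so `forward_batch` returns the largest batch size that survives five dummy forward passes at `max_context` width. A simplified sketch of that retry pattern, not Accelerate's actual implementation, with a toy memory limit standing in for the GPU:

import functools

def halve_on_oom(starting_batch_size=512):
    """Toy version of the halve-on-OOM retry decorator (illustrative only)."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper():
            bs = starting_batch_size
            while bs > 0:
                try:
                    return fn(bs)                 # try the current batch size
                except RuntimeError as err:       # real code checks for CUDA OOM specifically
                    if "out of memory" not in str(err).lower():
                        raise
                    bs //= 2                      # halve and retry
            raise RuntimeError("no executable batch size found")
        return wrapper
    return decorator

@halve_on_oom(starting_batch_size=512)
def forward_batch(batch_size):
    if batch_size > 100:                          # toy stand-in for the GPU memory limit
        raise RuntimeError("CUDA out of memory")
    return batch_size

print(forward_batch())  # 64: the first halving step that fits under the toy limit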