gaoqiong / lm-evaluation-harness

Commit f862a118 (unverified)
Authored Jun 12, 2023 by Stella Biderman; committed via GitHub on Jun 12, 2023
Parents: 23dcc12e, 8cec82b2

Merge pull request #572 from gakada/perf

    Add --max_batch_size and --batch_size auto:N

Showing 5 changed files with 68 additions and 59 deletions:

    lm_eval/base.py                +41  -37
    lm_eval/evaluator.py            +6   -2
    lm_eval/models/huggingface.py   +7  -16
    lm_eval/utils.py                +9   -3
    main.py                         +5   -1
lm_eval/base.py  (view file @ f862a118)

@@ -119,6 +119,12 @@ class LM(abc.ABC):

 class BaseLM(LM):
+    def __init__(self):
+        super().__init__()
+        self.batch_schedule = 1
+        self.batch_sizes = {}
+        self.max_batch_size = 512
+
     @property
     @abstractmethod
     def eot_token_id(self):

@@ -167,6 +173,26 @@ class BaseLM(LM):
         """
         pass

+    def _detect_batch_size(self, requests=None, pos=0):
+        if requests:
+            _, context_enc, continuation_enc = requests[pos]
+            max_length = len((context_enc + continuation_enc)[-(self.max_length + 1):][:-1])
+        else:
+            max_length = self.max_length
+
+        # if OOM, then halves batch_size and tries again
+        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
+        def forward_batch(batch_size):
+            test_batch = torch.ones((batch_size, max_length), device=self.device).long()
+            for _ in range(5):
+                _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
+            return batch_size
+
+        batch_size = forward_batch()
+        utils.clear_torch_cache()
+        return batch_size
+
     # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
     # TODO: enforce this somehow

@@ -202,19 +228,7 @@ class BaseLM(LM):
         if self.batch_size == "auto":
             # using rolling window with maximum context
             print("Passed argument batch_size = auto. Detecting largest batch size")
-            @find_executable_batch_size(starting_batch_size=512)
-            # if OOM, then halves batch_size and tries again
-            def forward_batch(batch_size):
-                test_batch = torch.ones(
-                    (batch_size, self.max_length), device=self.device
-                ).long()
-                for _ in range(5):
-                    _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
-                return batch_size
-
-            batch_size = forward_batch()
+            batch_size = self._detect_batch_size()
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size

@@ -267,34 +281,24 @@ class BaseLM(LM):
         re_ord = utils.Reorderer(requests, _collate)

+        reordered_requests = re_ord.get_reordered()
+        n_reordered_requests = len(reordered_requests)
+
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
-        if len(re_ord.get_reordered()) > 0:
-            _, context_enc, continuation_enc = re_ord.get_reordered()[0]
-            max_context = len((context_enc + continuation_enc)[-(self.max_length + 1):][:-1])
-            if (self.batch_size == 'auto'):
-                if override_bs is None:
-                    print('Passed argument batch_size = auto. Detecting largest batch size')
-                    @find_executable_batch_size(starting_batch_size=512)
-                    # if OOM, then halves batch_size and tries again
-                    def forward_batch(batch_size):
-                        test_batch = torch.ones((batch_size, max_context), device=self.device).long()
-                        for _ in range(5):
-                            out = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
-                        return batch_size
-
-                    batch_size = forward_batch()
-                    print(f"Determined largest batch size: {batch_size}")
-                    adaptive_batch_size = batch_size
-                else:
-                    adaptive_batch_size = override_bs
-        else:
-            adaptive_batch_size = 0 if override_bs is None else override_bs
+
+        def _batch_scheduler(pos):
+            sched = pos // int(n_reordered_requests / self.batch_schedule)
+            if sched in self.batch_sizes:
+                return self.batch_sizes[sched]
+            print(f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size")
+            self.batch_sizes[sched] = self._detect_batch_size(reordered_requests, pos)
+            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
+            return self.batch_sizes[sched]

         for chunk in utils.chunks(
-            tqdm(re_ord.get_reordered(), disable=disable_tqdm),
-            self.batch_size if self.batch_size != "auto" else adaptive_batch_size,
+            tqdm(reordered_requests, disable=disable_tqdm),
+            n=self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0,
+            fn=_batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 else None,
         ):
             inps = []
             cont_toks_list = []
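The `auto:N` scheduling introduced here is plain integer bucketing: request position `pos` maps to bucket `pos // int(n_reordered_requests / self.batch_schedule)`, and the batch size is detected once per bucket. A minimal standalone sketch of that arithmetic (values are illustrative, not taken from the diff):

```python
# Sketch of the _batch_scheduler bucketing: with batch_size = "auto:N",
# the reordered requests fall into N buckets and the largest feasible
# batch size is re-detected once per bucket.
def schedule_bucket(pos: int, n_requests: int, batch_schedule: int) -> int:
    return pos // int(n_requests / batch_schedule)

# e.g. 1000 requests with --batch_size auto:4 -> four detection passes
assert schedule_bucket(0, 1000, 4) == 0
assert schedule_bucket(499, 1000, 4) == 1
assert schedule_bucket(999, 1000, 4) == 3
```

Since `_collate` orders requests from longest to shortest context, later buckets generally hold shorter sequences, which is why re-detecting there tends to yield a larger usable batch size.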
lm_eval/evaluator.py  (view file @ f862a118)

@@ -16,6 +16,7 @@ def simple_evaluate(
     tasks=[],
     num_fewshot=0,
     batch_size=None,
+    max_batch_size=None,
     device=None,
     no_cache=False,
     limit=None,

@@ -37,8 +38,10 @@ def simple_evaluate(
         List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
     :param num_fewshot: int
         Number of examples in few-shot context
-    :param batch_size: int, optional
+    :param batch_size: int or str, optional
         Batch size for model
+    :param max_batch_size: int, optional
+        Maximal batch size to try with automatic batch size detection
     :param device: str, optional
         PyTorch device (e.g. "cpu" or "cuda:0") for running models
     :param no_cache: bool

@@ -67,7 +70,7 @@ def simple_evaluate(
         if model_args is None:
             model_args = ""
         lm = lm_eval.models.get_model(model).create_from_arg_string(
-            model_args, {"batch_size": batch_size, "device": device}
+            model_args, {"batch_size": batch_size, "max_batch_size": max_batch_size, "device": device}
         )
     else:
         assert isinstance(model, lm_eval.base.LM)

@@ -106,6 +109,7 @@ def simple_evaluate(
         "model_args": model_args,
         "num_fewshot": num_fewshot,
         "batch_size": batch_size,
+        "batch_sizes": list(lm.batch_sizes.values()),
         "device": device,
         "no_cache": no_cache,
         "limit": limit,
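For reference, a hypothetical call that exercises the new keyword arguments; the model name, `model_args`, and task list below are placeholders, not taken from this diff:

```python
# Hypothetical usage of simple_evaluate with the new arguments.
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="hf-causal",            # placeholder model name
    model_args="pretrained=gpt2",  # placeholder model args
    tasks=["lambada_openai"],      # placeholder task
    batch_size="auto:4",   # re-detect the batch size over 4 schedule buckets
    max_batch_size=256,    # upper bound for automatic batch size detection
    device="cuda:0",
)
# Detected per-bucket batch sizes are now recorded in the results config.
print(results["config"]["batch_sizes"])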
lm_eval/models/huggingface.py  (view file @ f862a118)

@@ -9,7 +9,6 @@ from typing import List, Mapping, NewType, Optional, Tuple, Union
 from tqdm import tqdm

 from transformers import BatchEncoding
-from accelerate import find_executable_batch_size

 from lm_eval import utils
 from lm_eval.base import BaseLM

@@ -76,6 +75,7 @@ class HuggingFaceAutoLM(BaseLM):
         subfolder: Optional[str] = None,
         revision: Optional[str] = "main",
         batch_size: Optional[Union[int, str]] = 1,
+        max_batch_size: Optional[int] = 512,
         max_gen_toks: Optional[int] = 256,
         max_length: Optional[int] = None,
         add_special_tokens: Optional[bool] = None,

@@ -172,10 +172,13 @@ class HuggingFaceAutoLM(BaseLM):
         ), "Evaluating causal models with `add_special_tokens=True` is currently not supported."

         # setup for automatic batch size detection
-        if batch_size == "auto":
-            self._batch_size = batch_size
+        if str(batch_size).startswith("auto"):
+            batch_size = batch_size.split(":")
+            self._batch_size = batch_size[0]
+            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
         else:
             self._batch_size = int(batch_size)
+        self.max_batch_size = max_batch_size

         self._max_gen_toks = max_gen_toks
         self._max_length = max_length

@@ -411,19 +414,7 @@ class HuggingFaceAutoLM(BaseLM):
         if self.batch_size == "auto":
             # using rolling window with maximum context
             print("Passed argument batch_size = auto. Detecting largest batch size")
-            @find_executable_batch_size(starting_batch_size=512)
-            # if OOM, then halves batch_size and tries again
-            def forward_batch(batch_size):
-                test_batch = torch.ones(
-                    (batch_size, self.max_length), device=self.device
-                ).long()
-                for _ in range(5):
-                    _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
-                return batch_size
-
-            batch_size = forward_batch()
+            batch_size = self._detect_batch_size()
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
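The constructor change boils down to a small parsing rule for the `batch_size` argument. A standalone sketch of that rule, separate from the class itself:

```python
# Standalone sketch of the batch_size parsing added above: "auto" and
# "auto:N" enable automatic detection (N sets batch_schedule); anything
# else is treated as a fixed integer batch size.
def parse_batch_size(batch_size):
    if str(batch_size).startswith("auto"):
        parts = str(batch_size).split(":")
        schedule = float(parts[1]) if len(parts) > 1 else 1
        return parts[0], schedule          # ("auto", batch_schedule)
    return int(batch_size), 1

assert parse_batch_size(8) == (8, 1)
assert parse_batch_size("auto") == ("auto", 1)
assert parse_batch_size("auto:4") == ("auto", 4.0)
```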
lm_eval/utils.py  (view file @ f862a118)

@@ -8,6 +8,7 @@ import sys
 import fnmatch
 from typing import List, Union

+import gc
 import torch
 from omegaconf import OmegaConf

@@ -64,11 +65,11 @@ def join_iters(iters):
         yield from iter


-def chunks(iter, n):
+def chunks(iter, n=0, fn=None):
     arr = []
-    for x in iter:
+    for i, x in enumerate(iter):
         arr.append(x)
-        if len(arr) == n:
+        if len(arr) == (fn(i) if fn else n):
             yield arr
             arr = []

@@ -283,3 +284,8 @@ def run_task_tests(task_list: List[str]):
         raise ValueError(
             f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
         )
+
+
+def clear_torch_cache():
+    gc.collect()
+    torch.cuda.empty_cache()
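With the new signature, `chunks` can either use a fixed size `n` or ask a callable `fn` for the size at each element index, which is what lets `_batch_scheduler` vary the batch size mid-run. An illustrative call with made-up sizes:

```python
# Illustrative behaviour of the extended chunks(): fixed-size chunks via n,
# or per-index sizes via fn (here, chunks of 2 before index 4, then 3).
from lm_eval import utils

print(list(utils.chunks(range(8), n=4)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]

print(list(utils.chunks(range(10), fn=lambda i: 2 if i < 4 else 3)))
# [[0, 1], [2, 3], [4, 5, 6], [7, 8, 9]]
```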
main.py  (view file @ f862a118)

@@ -16,6 +16,8 @@ def parse_args():
     parser.add_argument("--provide_description", action="store_true")
     parser.add_argument("--num_fewshot", type=int, default=0)
     parser.add_argument("--batch_size", type=str, default=None)
+    parser.add_argument("--max_batch_size", type=int, default=None,
+                        help="Maximal batch size to try with --batch_size auto")
     parser.add_argument("--device", type=str, default=None)
     parser.add_argument("--output_path", default=None)
     parser.add_argument("--limit", type=float, default=None,

@@ -60,6 +62,7 @@ def main():
         tasks=task_names,
         num_fewshot=args.num_fewshot,
         batch_size=args.batch_size,
+        max_batch_size=args.max_batch_size,
         device=args.device,
         no_cache=args.no_cache,
         limit=args.limit,

@@ -78,9 +81,10 @@ def main():
         with open(args.output_path, "w") as f:
             f.write(dumped)

+    batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
     print(
         f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
-        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
+        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
     )
     print(evaluator.make_table(results))
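Putting the CLI pieces together, an invocation along the lines of `python main.py --model gpt2 --tasks lambada_openai --batch_size auto:4 --max_batch_size 256 --device cuda:0` (model and task chosen only for illustration) would re-detect the largest workable batch size once per quarter of the sorted requests, never probing above 256 sequences at a time, and the final summary line would append the detected sizes in parentheses after `batch_size`.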