gaoqiong / lm-evaluation-harness

Commit d42e1706, authored Jul 10, 2023 by Benjamin Fattori
Parent: 7d4e92fa

    initial autobatching commit

Showing 2 changed files with 73 additions and 10 deletions:

    lm_eval/models/huggingface.py   +72  -9
    main.py                          +1  -1
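In short: batch_size may now be the string "auto" (or "auto:N") instead of an integer. When it is, HFLM probes the largest batch that fits in GPU memory using accelerate's find_executable_batch_size, caps the search at max_batch_size, and with "auto:N" re-detects the batch size N times as it works through the length-sorted requests.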
lm_eval/models/huggingface.py
@@ -17,7 +17,7 @@ from lm_eval.api.registry import register_model
 from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
-from accelerate import Accelerator
+from accelerate import Accelerator, find_executable_batch_size
 from typing import List, Optional, Union
@@ -67,7 +67,8 @@ class HFLM(LM):
         max_length: Optional[int] = None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
-        batch_size: Optional[int] = 1,
+        batch_size: Optional[Union[int, str]] = 1,
+        max_batch_size: Optional[int] = 64,
         low_cpu_mem_usage: Optional[bool] = True,
         trust_remote_code: Optional[bool] = False,
         # arguments used for splitting a model across GPUs naively.
@@ -90,7 +91,7 @@ class HFLM(LM):
         assert isinstance(device, str)
         assert isinstance(pretrained, str)
-        assert isinstance(batch_size, int)
+        assert isinstance(batch_size, (int, str))

         gpus = torch.cuda.device_count()
         accelerator = Accelerator()
@@ -223,8 +224,17 @@ class HFLM(LM):
         self._max_length = max_length

         # multithreading and batching
-        self.batch_size_per_gpu = batch_size
+        # TODO: Where to put this
+        self.batch_schedule = 1
+        self.batch_sizes = {}
+        self.max_batch_size = max_batch_size
+
+        if str(batch_size).startswith("auto"):
+            batch_size = batch_size.split(":")
+            self.batch_size_per_gpu = batch_size[0]
+            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
+        else:
+            self.batch_size_per_gpu = int(batch_size)

         # multigpu data-parallel support when launched with accelerate
         if gpus > 1:
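For orientation, a minimal usage sketch of the constructor change above; the model name and the concrete numbers are illustrative assumptions, not part of this commit.

    # Hypothetical usage sketch; "gpt2" and the values below are assumptions.
    from lm_eval.models.huggingface import HFLM

    lm = HFLM(
        pretrained="gpt2",     # any HF checkpoint name (assumed here)
        device="cuda",
        batch_size="auto:4",   # parsed above into batch_size_per_gpu="auto", batch_schedule=4.0
        max_batch_size=32,     # cap for the OOM-backoff search (default is 64)
    )
    # A plain integer still works: batch_size=8 falls through to int(batch_size).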
@@ -321,6 +331,29 @@ class HFLM(LM):
     def world_size(self):
         return self._world_size

+    def _detect_batch_size(self, requests=None, pos=0):
+        if requests:
+            _, context_enc, continuation_enc = requests[pos]
+            max_length = len(
+                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
+            )
+        else:
+            max_length = self.max_length
+
+        # if OOM, then halves batch_size and tries again
+        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
+        def forward_batch(batch_size):
+            test_batch = torch.ones((batch_size, max_length), device=self.device).long()
+            for _ in range(5):
+                _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
+            return batch_size
+
+        batch_size = forward_batch()
+
+        utils.clear_torch_cache()
+
+        return batch_size
+
     def tok_encode(self, string: str, left_truncate_len=None):
         """ """
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
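The helper leans on accelerate's find_executable_batch_size decorator, which reruns the wrapped function with the batch size halved each time it catches an out-of-memory style RuntimeError. A standalone sketch of that backoff behaviour, assuming a fake memory limit purely for illustration:

    # Backoff sketch; the "anything above 16 is OOM" threshold is an assumption.
    from accelerate import find_executable_batch_size

    @find_executable_batch_size(starting_batch_size=64)
    def probe(batch_size):
        if batch_size > 16:
            # accelerate treats this as OOM and retries with batch_size // 2
            raise RuntimeError("CUDA out of memory.")
        return batch_size

    print(probe())  # tries 64, then 32, then succeeds and prints 16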
@@ -459,6 +492,15 @@ class HFLM(LM):
     def loglikelihood_rolling(self, requests):
         loglikelihoods = []

+        adaptive_batch_size = None
+        if self.batch_size == "auto":
+            # using rolling window with maximum context
+            print("Passed argument batch_size = auto. Detecting largest batch size")
+            batch_size = self._detect_batch_size()
+            print(f"Determined Largest batch size: {batch_size}")
+            adaptive_batch_size = batch_size
+
         for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
             rolling_token_windows = list(
                 map(
@@ -502,7 +544,7 @@ class HFLM(LM):
         return loglikelihoods

-    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
@@ -516,12 +558,33 @@ class HFLM(LM):
             toks = x[1] + x[2]
             return -len(toks), tuple(toks)

-        # TODO: automatic (variable) batch size detection for vectorization
         re_ord = utils.Reorderer(requests, _collate)

+        n_reordered_requests = len(re_ord.get_reordered())
+
+        # automatic (variable) batch size detection for vectorization
+        # pull longest context sample from request
+        def _batch_scheduler(pos):
+            sched = pos // int(n_reordered_requests / self.batch_schedule)
+            if sched in self.batch_sizes:
+                return self.batch_sizes[sched]
+            print(f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size")
+            self.batch_sizes[sched] = self._detect_batch_size(re_ord.get_reordered(), pos)
+            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
+            return self.batch_sizes[sched]
+
         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
-            self.batch_size,
+            n=self.batch_size
+            if self.batch_size != "auto"
+            else override_bs
+            if override_bs is not None
+            else 0,
+            fn=_batch_scheduler
+            if self.batch_size == "auto"
+            and n_reordered_requests > 0
+            and not override_bs
+            else None,
         ):
             inps = []
             cont_toks_list = []
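Because _collate orders requests longest-first, later positions in the reordered list hold shorter sequences, so re-detecting the batch size once per segment lets those segments run with larger batches. A worked example of the bucketing arithmetic, with assumed numbers:

    # Scheduler arithmetic sketch; the request count and schedule are assumptions.
    n_reordered_requests = 1000
    batch_schedule = 2.0
    segment = int(n_reordered_requests / batch_schedule)  # 500 positions per bucket
    print(0 // segment, 499 // segment, 500 // segment, 999 // segment)  # 0 0 1 1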
main.py
@@ -22,7 +22,7 @@ def parse_args():
     )
     parser.add_argument("--config", default=None)
     parser.add_argument("--num_fewshot", type=int, default=0)
-    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--batch_size", type=str, default=1)
     parser.add_argument(
         "--max_batch_size",
         type=int,
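With --batch_size now parsed as a string, automatic detection can be requested from the command line, e.g. --batch_size auto or --batch_size auto:4, with --max_batch_size capping the probe; integer values such as --batch_size 8 keep working because HFLM falls back to int(batch_size). How the value reaches the model constructor is unchanged by this commit.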