Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
9a877197
Unverified
Commit
9a877197
authored
May 02, 2023
by
Stella Biderman
Committed by
GitHub
May 02, 2023
Browse files
Merge pull request #394 from fattorib/auto-batching
single GPU automatic batching logic
parents
fc4428dc
d6ceced5
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
92 additions
and
19 deletions
+92
-19
lm_eval/base.py
lm_eval/base.py
+44
-5
lm_eval/models/gpt2.py
lm_eval/models/gpt2.py
+16
-9
lm_eval/models/huggingface.py
lm_eval/models/huggingface.py
+30
-4
main.py
main.py
+1
-1
setup.py
setup.py
+1
-0
No files found.
lm_eval/base.py
View file @
9a877197
...
...
@@ -11,6 +11,7 @@ from sqlitedict import SqliteDict
from
tqdm
import
tqdm
import
torch
import
torch.nn.functional
as
F
from
accelerate
import
find_executable_batch_size
from
lm_eval.metrics
import
mean
,
weighted_perplexity
,
weighted_mean
,
bits_per_byte
from
lm_eval
import
utils
...
...
@@ -186,7 +187,22 @@ class BaseLM(LM):
def
loglikelihood_rolling
(
self
,
requests
):
# TODO: Implement caching once we've confirmed the perplexity implementation
# TODO: automatic batch size detection for vectorization
# automatic batch size detection for vectorization
adaptive_batch_size
=
None
if
self
.
batch_size
==
'auto'
:
# using rolling window with maximum context
print
(
'Passed argument batch_size = auto. Detecting largest batch size'
)
@
find_executable_batch_size
(
starting_batch_size
=
512
)
# if OOM, then halves batch_size and tries again
def
forward_batch
(
batch_size
):
test_batch
=
torch
.
ones
((
batch_size
,
self
.
max_length
),
device
=
self
.
device
).
long
()
for
_
in
range
(
5
):
out
=
F
.
log_softmax
(
self
.
_model_call
(
test_batch
),
dim
=
-
1
).
cpu
()
return
batch_size
batch_size
=
forward_batch
()
print
(
f
"Determined Largest batch size:
{
batch_size
}
"
)
adaptive_batch_size
=
batch_size
loglikelihoods
=
[]
for
(
string
,)
in
tqdm
(
requests
):
...
...
@@ -207,7 +223,7 @@ class BaseLM(LM):
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
# that
string_nll
=
self
.
_loglikelihood_tokens
(
rolling_token_windows
,
disable_tqdm
=
True
rolling_token_windows
,
disable_tqdm
=
True
,
override_bs
=
adaptive_batch_size
)
# discard is_greedy
...
...
@@ -218,7 +234,7 @@ class BaseLM(LM):
return
loglikelihoods
def
_loglikelihood_tokens
(
self
,
requests
,
disable_tqdm
=
False
):
def
_loglikelihood_tokens
(
self
,
requests
,
disable_tqdm
=
False
,
override_bs
=
None
):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res
=
[]
...
...
@@ -233,10 +249,33 @@ class BaseLM(LM):
toks
=
x
[
1
]
+
x
[
2
]
return
-
len
(
toks
),
tuple
(
toks
)
# TODO: automatic (variable) batch size detection for vectorization
re_ord
=
utils
.
Reorderer
(
requests
,
_collate
)
# automatic (variable) batch size detection for vectorization
# pull longest context sample from request
_
,
context_enc
,
continuation_enc
=
re_ord
.
get_reordered
()[
0
]
max_context
=
len
((
context_enc
+
continuation_enc
)[
-
(
self
.
max_length
+
1
)
:][:
-
1
])
if
(
self
.
batch_size
==
'auto'
):
if
override_bs
is
None
:
print
(
'Passed argument batch_size = auto. Detecting largest batch size'
)
@
find_executable_batch_size
(
starting_batch_size
=
512
)
# if OOM, then halves batch_size and tries again
def
forward_batch
(
batch_size
):
test_batch
=
torch
.
ones
((
batch_size
,
max_context
),
device
=
self
.
device
).
long
()
for
_
in
range
(
5
):
out
=
F
.
log_softmax
(
self
.
_model_call
(
test_batch
),
dim
=
-
1
).
cpu
()
return
batch_size
batch_size
=
forward_batch
()
print
(
f
"Determined largest batch size:
{
batch_size
}
"
)
adaptive_batch_size
=
batch_size
else
:
adaptive_batch_size
=
override_bs
for
chunk
in
utils
.
chunks
(
tqdm
(
re_ord
.
get_reordered
(),
disable
=
disable_tqdm
),
self
.
batch_size
tqdm
(
re_ord
.
get_reordered
(),
disable
=
disable_tqdm
),
self
.
batch_size
if
self
.
batch_size
!=
"auto"
else
adaptive_batch_size
):
inps
=
[]
cont_toks_list
=
[]
...
...
lm_eval/models/gpt2.py
View file @
9a877197
...
...
@@ -3,7 +3,6 @@ import transformers
from
typing
import
Optional
from
lm_eval.base
import
BaseLM
class
HFLM
(
BaseLM
):
def
__init__
(
self
,
...
...
@@ -21,7 +20,7 @@ class HFLM(BaseLM):
assert
isinstance
(
device
,
str
)
assert
isinstance
(
pretrained
,
str
)
assert
isinstance
(
batch_size
,
int
)
assert
isinstance
(
batch_size
,
(
int
,
str
)
)
device_list
=
set
([
"cuda"
,
"cpu"
]
+
[
f
'cuda:
{
i
}
'
for
i
in
range
(
torch
.
cuda
.
device_count
())])
if
device
and
device
in
device_list
:
...
...
@@ -56,13 +55,21 @@ class HFLM(BaseLM):
self
.
vocab_size
=
self
.
tokenizer
.
vocab_size
# multithreading and batching
self
.
batch_size_per_gpu
=
batch_size
# todo: adaptive batch size
# TODO: fix multi-gpu
# gpus = torch.cuda.device_count()
# if gpus > 1:
# self.gpt2 = nn.DataParallel(self.gpt2)
if
isinstance
(
self
.
tokenizer
,
(
transformers
.
GPT2Tokenizer
,
transformers
.
GPT2TokenizerFast
)
):
assert
self
.
tokenizer
.
encode
(
"hello
\n\n
hello"
)
==
[
31373
,
198
,
198
,
31373
,
],
self
.
tokenizer
.
encode
(
"hello
\n\n
hello"
)
# setup for automatic batch size detection
if
batch_size
==
'auto'
:
self
.
batch_size_per_gpu
=
batch_size
else
:
self
.
batch_size_per_gpu
=
int
(
batch_size
)
@
property
def
eot_token_id
(
self
):
...
...
lm_eval/models/huggingface.py
View file @
9a877197
...
...
@@ -7,6 +7,7 @@ from typing import List, Mapping, NewType, Optional, Tuple, Union
from
tqdm
import
tqdm
from
transformers
import
BatchEncoding
from
accelerate
import
find_executable_batch_size
from
lm_eval
import
utils
from
lm_eval.base
import
BaseLM
...
...
@@ -71,7 +72,7 @@ class HuggingFaceAutoLM(BaseLM):
tokenizer
:
Optional
[
str
]
=
None
,
subfolder
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
"main"
,
batch_size
:
Optional
[
int
]
=
1
,
batch_size
:
Optional
[
Union
[
int
,
str
]
]
=
1
,
max_gen_toks
:
Optional
[
int
]
=
256
,
max_length
:
Optional
[
int
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
,
...
...
@@ -143,7 +144,7 @@ class HuggingFaceAutoLM(BaseLM):
assert
isinstance
(
pretrained
,
str
)
assert
isinstance
(
device
,
str
)
assert
isinstance
(
batch_size
,
int
)
assert
isinstance
(
batch_size
,
(
int
,
str
)
)
if
(
add_special_tokens
is
not
None
and
self
.
AUTO_MODEL_CLASS
is
transformers
.
AutoModelForCausalLM
...
...
@@ -157,7 +158,12 @@ class HuggingFaceAutoLM(BaseLM):
not
add_special_tokens
),
"Evaluating causal models with `add_special_tokens=True` is currently not supported."
self
.
_batch_size
=
batch_size
# TODO: Adaptive batch size
# setup for automatic batch size detection
if
batch_size
==
'auto'
:
self
.
_batch_size
=
batch_size
else
:
self
.
_batch_size
=
int
(
batch_size
)
self
.
_max_gen_toks
=
max_gen_toks
self
.
_max_length
=
max_length
self
.
_config
=
self
.
AUTO_CONFIG_CLASS
.
from_pretrained
(
...
...
@@ -366,10 +372,30 @@ class HuggingFaceAutoLM(BaseLM):
tokens
=
self
.
tok_encode
(
x
[
0
])
return
len
(
tokens
),
x
[
0
]
results
=
[]
reorder
=
utils
.
Reorderer
(
requests
,
_collate
)
_
,
context_enc
,
continuation_enc
=
reorder
.
get_reordered
()[
0
]
max_context
=
len
((
context_enc
+
continuation_enc
)[
-
(
self
.
max_length
+
1
)
:][:
-
1
])
adaptive_batch_size
=
None
if
self
.
batch_size
==
'auto'
:
# using rolling window with maximum context
print
(
'Passed argument batch_size = auto. Detecting largest batch size'
)
@
find_executable_batch_size
(
starting_batch_size
=
512
)
# if OOM, then halves batch_size and tries again
def
forward_batch
(
batch_size
):
test_batch
=
torch
.
ones
((
batch_size
,
max_context
),
device
=
self
.
device
).
long
()
for
_
in
range
(
5
):
out
=
F
.
log_softmax
(
self
.
_model_call
(
test_batch
),
dim
=
-
1
).
cpu
()
return
batch_size
batch_size
=
forward_batch
()
print
(
f
"Determined Largest batch size:
{
batch_size
}
"
)
adaptive_batch_size
=
batch_size
for
chunk
in
utils
.
chunks
(
tqdm
(
reorder
.
get_reordered
(),
disable
=
False
),
self
.
batch_size
tqdm
(
reorder
.
get_reordered
(),
disable
=
False
),
self
.
batch_size
if
self
.
batch_size
!=
"auto"
else
adaptive_batch_size
):
context
=
[
c
[
0
]
for
c
in
chunk
]
request_args
=
chunk
[
0
][
1
]
...
...
main.py
View file @
9a877197
...
...
@@ -32,7 +32,7 @@ def parse_args():
parser
.
add_argument
(
"--tasks"
,
default
=
None
,
choices
=
MultiChoice
(
tasks
.
ALL_TASKS
))
parser
.
add_argument
(
"--provide_description"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--num_fewshot"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--batch_size"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--output_path"
,
default
=
None
)
parser
.
add_argument
(
"--limit"
,
type
=
int
,
default
=
None
)
...
...
setup.py
View file @
9a877197
...
...
@@ -38,6 +38,7 @@ setuptools.setup(
"tqdm-multiprocess"
,
"transformers>=4.1"
,
"zstandard"
,
"accelerate>=0.17.1"
],
extras_require
=
{
"dev"
:
[
"black"
,
"flake8"
,
"pre-commit"
,
"pytest"
,
"pytest-cov"
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment