OpenDAS / AutoAWQ · Commits

Commit 8dbf009b (unverified)
Authored Dec 03, 2023 by Casper; committed by GitHub on Dec 03, 2023
Benchmark hf generate (#237)
Parent: d1112e1c

Showing 1 changed file with 83 additions and 12 deletions.

examples/benchmark.py  +83  -12
examples/benchmark.py (view file @ 8dbf009b)
...
@@ -4,14 +4,38 @@ import argparse
 import numpy as np
 import pandas as pd
 from awq import AutoAWQForCausalLM
-from transformers import AutoTokenizer
 from torch.cuda import OutOfMemoryError
+from awq.models.base import BaseAWQForCausalLM
+from transformers import AutoTokenizer, GenerationConfig, LogitsProcessor, LogitsProcessorList
+
+
+class TimeMeasuringLogitsProcessor(LogitsProcessor):
+    def __init__(self):
+        self.token_times = [time.time()]
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+        """The logit processor is called after the model forward."""
+        # cuda runs async operates, so we synchronize for accurate time measurement
+        torch.cuda.synchronize()
+
+        # measure time
+        start_time = time.time()
+        self.token_times.append(start_time)
+        return scores
+
+    def get_prefill_duration(self):
+        return self.token_times[1] - self.token_times[0]
+
+    def get_decode_durations(self):
+        token_times = self.token_times[1:]
+        token_durations = [token_times[i+1] - token_times[i] for i in range(len(token_times) - 1)]
+
+        return token_durations
+
 def warmup(model):
     warm_up = torch.randn((4096, 4096)).to(next(model.parameters()).device)
     torch.mm(warm_up, warm_up)
 
-def generate(model, input_ids, n_generate):
+def generate_torch(model, input_ids, n_generate):
     context_time = 0
     generate_time = []
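Editor's note: the TimeMeasuringLogitsProcessor added above records a timestamp each time the model produces logits, so the first interval is the prefill time and the later intervals are per-token decode times. Below is a minimal, hedged sketch of the same pattern used standalone with a plain transformers model; it is not part of this commit, it assumes a CUDA device, and the model name, prompt, and token count are placeholders.

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList

class TokenTimer(LogitsProcessor):
    # hypothetical helper mirroring TimeMeasuringLogitsProcessor above
    def __init__(self):
        self.token_times = [time.time()]

    def __call__(self, input_ids, scores):
        torch.cuda.synchronize()              # wait for queued CUDA work before timing
        self.token_times.append(time.time())  # one timestamp per generated token
        return scores                         # scores pass through unchanged

tok = AutoTokenizer.from_pretrained("gpt2")                       # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
timer = TokenTimer()
inputs = tok("Hello", return_tensors="pt").to("cuda")
model.generate(**inputs, max_new_tokens=16,
               logits_processor=LogitsProcessorList([timer]))
prefill_s = timer.token_times[1] - timer.token_times[0]
decode_s = [b - a for a, b in zip(timer.token_times[1:], timer.token_times[2:])]
print(f"prefill: {prefill_s:.3f}s, mean decode: {sum(decode_s) / max(len(decode_s), 1):.4f}s/token")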
...
@@ -39,21 +63,52 @@ def generate(model, input_ids, n_generate):
 
     return context_time, generate_time
 
-def run_round(model_path, quant_file, n_generate, input_ids, batch_size, no_safetensors):
-    print(f" -- Loading model...")
-    model = AutoAWQForCausalLM.from_quantized(
-        model_path, quant_file, fuse_layers=True,
-        max_new_tokens=n_generate, batch_size=batch_size,
-        safetensors=not no_safetensors
-    )
+def generate_hf(model: BaseAWQForCausalLM, input_ids, n_generate):
+    generation_config = GenerationConfig(
+        min_new_tokens=n_generate,
+        max_new_tokens=n_generate,
+        use_cache=True,
+        forced_eos_token_id=-100,
+        eos_token_id=-100,
+    )
+
+    time_processor = TimeMeasuringLogitsProcessor()
+    model.generate(
+        input_ids,
+        generation_config=generation_config,
+        logits_processor=LogitsProcessorList([time_processor]),
+    )
+
+    context_time = time_processor.get_prefill_duration()
+    generate_time = time_processor.get_decode_durations()
+
+    return context_time, generate_time
+
+def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_size, no_safetensors, pretrained):
+    print(f" -- Loading model...")
+
+    if pretrained:
+        model = AutoAWQForCausalLM.from_pretrained(
+            model_path,
+            safetensors=not no_safetensors,
+            device_map="cuda",
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = AutoAWQForCausalLM.from_quantized(
+            model_path, quant_file, fuse_layers=True,
+            max_new_tokens=n_generate, batch_size=batch_size,
+            safetensors=not no_safetensors
+        )
 
     print(f" -- Warming up...")
     warmup(model)
 
     print(f" -- Generating {n_generate} tokens, {input_ids.shape[1]} in context...")
 
     try:
-        context_time, generate_time = generate(model, input_ids, n_generate)
+        context_time, generate_time = generator(model, input_ids, n_generate)
         successful_generate = True
     except RuntimeError as ex:
         if 'cuda out of memory' in str(ex).lower():
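Editor's note: in generate_hf the intent of pinning min_new_tokens to max_new_tokens and pointing eos_token_id (and forced_eos_token_id) at -100, an id ordinary decoding never emits, is to keep HF generate from stopping early, so every run produces exactly n_generate tokens. The durations returned by either generator can then be turned into throughput numbers; a rough post-processing sketch using the names from run_round above (the exact formulas used later in this script are not shown here) might be:

# context_time: seconds spent on the prompt (prefill)
# generate_time: list of per-token decode durations in seconds
prefill_tokens_per_second = input_ids.shape[1] / context_time
decode_tokens_per_second = len(generate_time) / sum(generate_time)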
...
@@ -77,6 +132,11 @@ def run_round(model_path, quant_file, n_generate, input_ids, batch_size, no_safe
     else:
         prefill_tokens_per_second = 'OOM'
         decode_tokens_per_second = 'OOM'
 
+    if pretrained:
+        version = "FP16"
+    else:
+        version = model.quant_config.version
+
     return {
         "Batch Size": batch_size,
...
@@ -85,7 +145,7 @@ def run_round(model_path, quant_file, n_generate, input_ids, batch_size, no_safe
         "Prefill tokens/s": prefill_tokens_per_second,
         "Decode tokens/s": decode_tokens_per_second,
         "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
-    }, model.quant_config.version
+    }, version
 
 def main(args):
     rounds = [
...
@@ -98,6 +158,13 @@ def main(args):
         {"context": 2048, "n_generate": 2048},
     ]
 
+    if args.generator == "torch":
+        generator = generate_torch
+    elif args.generator == "hf":
+        generator = generate_hf
+    else:
+        raise ValueError(f"Unknown generator method passed: {args.generator}")
+
     all_stats = []
     tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
...
@@ -105,12 +172,14 @@ def main(args):
         input_ids = torch.randint(0, tokenizer.vocab_size, (args.batch_size, settings["context"])).cuda()
 
         stats, model_version = run_round(
+            generator,
             args.model_path,
             args.quant_file,
             settings["n_generate"],
             input_ids,
             args.batch_size,
-            args.no_safetensors
+            args.no_safetensors,
+            args.pretrained
         )
 
         all_stats.append(stats)
...
@@ -130,6 +199,8 @@ if __name__ == "__main__":
     parser.add_argument("--quant_file", type=str, default="", help="weights filename")
     parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
     parser.add_argument("--no_safetensors", default=False, action="store_true", help="Use for disabling safetensors")
+    parser.add_argument("--generator", type=str, default="torch", choices=["torch", "hf"], help="weights filename")
+    parser.add_argument("--pretrained", default=False, action="store_true", help="Measure pretrained model.")
     args = parser.parse_args()
 
     main(args)
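Editor's note: with this commit the benchmark can drive either the custom torch generation loop or HF generate, and can also measure an unquantized FP16 checkpoint. Hedged usage examples follow; the model paths are placeholders, only the flags come from the diff above.

# Benchmark an AWQ checkpoint through HF generate (paths are examples only)
python examples/benchmark.py --model_path TheBloke/Mistral-7B-v0.1-AWQ --generator hf

# Benchmark the unquantized FP16 model for comparison
python examples/benchmark.py --model_path mistralai/Mistral-7B-v0.1 --pretrained --generator hf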