OpenDAS / AutoAWQ · Commit 46750ff9

Add speed benchmark

Authored Aug 24, 2023 by Casper Hansen (parent commit: 63346c34)
Showing 1 changed file with 60 additions and 2 deletions.
awq/entry.py (+60, -2)
```diff
 import os
+import time
 import torch
 import argparse
 from lm_eval import evaluator
```
```diff
@@ -72,6 +73,56 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
     print(evaluator.make_table(results))
 
+@torch.inference_mode()
+def run_speed(model_path, quant_file, device, n_generate=128, max_new_tokens=256):
+    def _timer(func):
+        start = time.time()
+        out = func()
+        return out, time.time() - start
+
+    def _generate(model, model_out, n_generate):
+        past_key_values = model_out.past_key_values
+
+        for i in range(n_generate):
+            logits = model_out.logits[0, -1, :]
+            probs = torch.softmax(logits, dim=-1)
+            token = torch.multinomial(probs, num_samples=1)
+            token = torch.as_tensor([token], device=device).unsqueeze(0)
+
+            model_out = model(token, use_cache=True, past_key_values=past_key_values)
+
+    def _warmup(device: str):
+        warm_up = torch.randn((4096, 4096)).to(device)
+        torch.mm(warm_up, warm_up)
+
+    # Load model
+    model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True))
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    _warmup(device)
+
+    # Generate random inputs
+    n_context = max_new_tokens - n_generate
+    ids = torch.randint(0, tokenizer.vocab_size, (1, n_context)).cuda()
+
+    # Context stage
+    model_out, context_time = _timer(lambda: model(ids, use_cache=True))
+
+    # Generation stage
+    _, generation_time = _timer(lambda: _generate(model, model_out, n_generate))
+
+    # Prints
+    memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
+    context_tokens_per_second = n_context / context_time
+    context_ms_per_token = (context_time * 1000) / n_context
+    inference_tokens_per_second = n_generate / generation_time
+    inference_ms_per_token = (generation_time * 1000) / n_generate
+
+    print(f"[======] Model summary: {model_path} [======]")
+    print(f"[*] Load time: {load_time:.2f} seconds")
+    print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
+    print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
+    print(f"[*] VRAM: {memory_used:.2f} MB")
 
 if __name__ == '__main__':
     """
     - Run AWQ search and save result:
```
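The `_timer` helper in the new hunk measures wall-clock time around a callable. Since CUDA kernels launch asynchronously, a plain wall-clock timer can return before queued GPU work has finished; a common variant synchronizes the device on both sides of the measurement. Below is a minimal sketch of that variant, assuming a CUDA device is available; it is illustrative only and not part of this commit.

```python
import time

import torch


def timer_synced(func, device: str = "cuda:0"):
    """Like the commit's _timer, but waits for outstanding CUDA work.
    Sketch only -- not part of the commit above."""
    torch.cuda.synchronize(device)  # drain kernels queued before the measurement
    start = time.time()
    out = func()
    torch.cuda.synchronize(device)  # ensure the measured work actually finished
    return out, time.time() - start
```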
```diff
@@ -85,9 +136,12 @@ if __name__ == '__main__':
     - Run perplexity unquantized FP16 model:
         python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
+
+    - Run a speedtest to benchmark the quantized model:
+        python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
     """
     parser = argparse.ArgumentParser()
-    parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval)')
+    parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval|speed)')
     parser.add_argument('--model_path', type=str, help='Path to hf model')
     parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
     parser.add_argument('--quant_path', type=str, help='Path to save AWQ model to directory')
```
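The new speed example in the docstring maps directly onto `run_speed`'s signature, so the benchmark can also be driven from Python rather than through argparse. A minimal sketch, reusing the example checkpoint names from the docstring above (placeholders, not fixed defaults):

```python
# Sketch: the model and quant-file names below are the docstring's example
# values, not names required by the code.
from awq.entry import run_speed

run_speed(
    model_path="vicuna-7b-v1.5-awq",    # directory holding the quantized model
    quant_file="awq_model_w4_g128.pt",  # 4-bit, group-size-128 AWQ weights
    device="cuda:0",
    n_generate=128,                     # tokens timed in the generation stage
    max_new_tokens=256,                 # prompt length = max_new_tokens - n_generate
)
```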
```diff
@@ -102,6 +156,8 @@ if __name__ == '__main__':
         help="Pass '--task_use_pretrained' to use a pretrained model running FP16")
     parser.add_argument('--task_batch_size', type=int, default=1)
     parser.add_argument('--task_n_shot', type=int, default=0)
+    parser.add_argument('--n_generate', type=int, default=128)
+    parser.add_argument('--n_context', type=int, default=256)
     args = parser.parse_args()
 
     quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }
```
```diff
@@ -113,5 +169,7 @@ if __name__ == '__main__':
     elif args.entry_type == 'eval':
         run_eval(args.model_path, args.quant_file, args.device,
                  args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
+    elif args.entry_type == 'speed':
+        run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context)
     else:
-        raise Exception('--entry_type must be one of (search|quant|eval)')
\ No newline at end of file
+        raise Exception('--entry_type must be one of (search|quant|eval|speed)')
\ No newline at end of file
```
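One detail worth noting in the dispatch above: `args.n_context` is passed positionally into `run_speed`'s `max_new_tokens` parameter, and the function derives the prompt length as `max_new_tokens - n_generate`. With the new defaults, the benchmarked context therefore holds 256 - 128 = 128 tokens rather than 256.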