Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
abdc726c
Unverified
Commit
abdc726c
authored
Sep 05, 2023
by
Casper
Committed by
GitHub
Sep 05, 2023
Browse files
Merge pull request #25 from wanzhenchn/main
support speedtest to benchmark FP16 model
parents
637d4abd
4f42f509
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
14 deletions
+20
-14
awq/entry.py
awq/entry.py
+20
-14
No files found.
awq/entry.py
View file @
abdc726c
...
...
@@ -12,7 +12,7 @@ from awq.utils.lm_eval_adaptor import LMEvalAdaptor
def load_search_result_into_memory(model, search_path):
    """Load serialized AWQ search results from disk and apply them to *model*.

    Args:
        model: model whose layers receive the AWQ scales and clipping values
            (modified in place).
        search_path: path to a ``torch.save``-ed dict containing ``"scale"``
            and ``"clip"`` entries.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load search
    # results from trusted sources.
    awq_results = torch.load(search_path, map_location="cpu")
    scales = awq_results["scale"]
    clips = awq_results["clip"]
    apply_scale(model, scales)
    apply_clip(model, clips)
...
...
@@ -56,7 +56,7 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
model
=
AutoAWQForCausalLM
.
from_pretrained
(
model_path
)
else
:
model
=
AutoAWQForCausalLM
.
from_quantized
(
model_path
,
quant_file
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
)
# Load adapter
...
...
@@ -74,12 +74,12 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
print
(
evaluator
.
make_table
(
results
))
@
torch
.
inference_mode
()
def
run_speed
(
model_path
,
quant_file
,
device
,
n_generate
=
128
,
max_new_tokens
=
256
):
def
run_speed
(
model_path
,
quant_file
,
device
,
n_generate
=
128
,
n_context
=
256
):
def
_timer
(
func
):
start
=
time
.
time
()
out
=
func
()
return
out
,
time
.
time
()
-
start
def
_generate
(
model
,
model_out
,
n_generate
):
past_key_values
=
model_out
.
past_key_values
...
...
@@ -90,18 +90,21 @@ def run_speed(model_path, quant_file, device, n_generate=128, max_new_tokens=256
token
=
torch
.
as_tensor
([
token
],
device
=
device
).
unsqueeze
(
0
)
model_out
=
model
(
token
,
use_cache
=
True
,
past_key_values
=
past_key_values
)
def
_warmup
(
device
:
str
):
warm_up
=
torch
.
randn
((
4096
,
4096
)).
to
(
device
)
torch
.
mm
(
warm_up
,
warm_up
)
# Load model
model
,
load_time
=
_timer
(
lambda
:
AutoAWQForCausalLM
.
from_quantized
(
model_path
,
quant_file
,
fuse_layers
=
True
))
if
quant_file
:
model
,
load_time
=
_timer
(
lambda
:
AutoAWQForCausalLM
.
from_quantized
(
model_path
,
quant_file
,
fuse_layers
=
True
))
else
:
model
,
load_time
=
_timer
(
lambda
:
AutoAWQForCausalLM
.
from_pretrained
(
model_path
))
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
)
_warmup
(
device
)
# Generate random inputs
n_context
=
max_new_tokens
-
n_generate
n_context
=
n_context
-
n_generate
ids
=
torch
.
randint
(
0
,
tokenizer
.
vocab_size
,
(
1
,
n_context
)).
cuda
()
# Context stage
...
...
@@ -109,7 +112,7 @@ def run_speed(model_path, quant_file, device, n_generate=128, max_new_tokens=256
# Generation stage
_
,
generation_time
=
_timer
(
lambda
:
_generate
(
model
,
model_out
,
n_generate
))
# Prints
memory_used
=
torch
.
cuda
.
max_memory_allocated
(
device
)
/
(
1024
**
2
)
context_tokens_per_second
=
n_context
/
context_time
...
...
@@ -138,7 +141,10 @@ if __name__ == '__main__':
python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
- Run a speedtest to benchmark the quantized model:
python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt --n_generate 128 --n_context 256
- Run a speedtest to benchmark the unquantized FP16 model:
python -m awq.entry --entry_type speed --model_path lmsys/vicuna-7b-v1.5 --n_generate 128 --n_context 256
"""
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--entry_type'
,
type
=
str
,
help
=
'The type of task to run (search|quant|eval|speed)'
)
...
...
@@ -161,15 +167,15 @@ if __name__ == '__main__':
args
=
parser
.
parse_args
()
quant_config
=
{
"zero_point"
:
True
,
"q_group_size"
:
args
.
q_group_size
,
"w_bit"
:
args
.
w_bit
}
if
args
.
entry_type
==
'search'
:
run_search
(
args
.
model_path
,
args
.
search_path
,
quant_config
)
elif
args
.
entry_type
==
'quant'
:
run_quant
(
args
.
model_path
,
args
.
search_path
,
args
.
quant_path
,
quant_config
)
elif
args
.
entry_type
==
'eval'
:
run_eval
(
args
.
model_path
,
args
.
quant_file
,
args
.
device
,
run_eval
(
args
.
model_path
,
args
.
quant_file
,
args
.
device
,
args
.
tasks
,
args
.
task_batch_size
,
args
.
task_n_shot
,
args
.
task_use_pretrained
)
elif
args
.
entry_type
==
'speed'
:
run_speed
(
args
.
model_path
,
args
.
quant_file
,
args
.
device
,
args
.
n_generate
,
args
.
n_context
)
else
:
raise
Exception
(
'--entry_type must be one of (search|quant|eval|speed)'
)
\ No newline at end of file
raise
Exception
(
'--entry_type must be one of (search|quant|eval|speed)'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment