OpenDAS / AutoAWQ / Commits / 4f42f509
Commit 4f42f509, authored Sep 05, 2023 by Casper Hansen
Revert to n_context arg
Parent: bbbd525e
Showing 1 changed file with 6 additions and 9 deletions:

awq/entry.py   +6  -9
awq/entry.py  (view file @ 4f42f509)
@@ -74,7 +74,7 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
     print(evaluator.make_table(results))
 
 @torch.inference_mode()
-def run_speed(model_path, quant_file, device, n_generate=128, max_seq_len=256):
+def run_speed(model_path, quant_file, device, n_generate=128, n_context=256):
     def _timer(func):
         start = time.time()
         out = func()
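
The hunk cuts `_timer` off after its first two lines; from the call sites further down (`model, load_time = _timer(lambda: ...)`) it evidently returns the callable's result together with the elapsed wall-clock time. A minimal sketch of that shape, assuming nothing beyond what the diff shows:

    # Sketch only: the body past `out = func()` is not visible in this hunk.
    # The return shape is inferred from `model, load_time = _timer(...)` below.
    import time

    def _timer(func):
        start = time.time()
        out = func()
        return out, time.time() - start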
@@ -95,18 +95,16 @@ def run_speed(model_path, quant_file, device, n_generate=128, max_seq_len=256):
         warm_up = torch.randn((4096, 4096)).to(device)
         torch.mm(warm_up, warm_up)
 
-    # Load model
     if quant_file:
         model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True))
     else:
-        # fp16 model
         model, load_time = _timer(lambda: AutoAWQForCausalLM.from_pretrained(model_path))
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     _warmup(device)
 
     # Generate random inputs
-    n_context = max_seq_len - n_generate
+    n_context = n_context - n_generate
     ids = torch.randint(0, tokenizer.vocab_size, (1, n_context)).cuda()
 
     # Context stage
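
Note the changed semantics: `--n_context` is reduced by `n_generate` before the random prompt is built, so with the defaults (`n_generate=128`, `n_context=256`) the context stage runs on a 256 - 128 = 128-token prompt. A small stand-alone illustration of that arithmetic, with a hypothetical `vocab_size` and a CPU tensor in place of the real tokenizer and the `.cuda()` call:

    import torch

    n_generate = 128                      # tokens to generate (argparse default)
    n_context = 256                       # argparse default for --n_context
    n_context = n_context - n_generate    # effective prompt length, as in the diff

    vocab_size = 32000                    # hypothetical; the real value comes from the tokenizer
    ids = torch.randint(0, vocab_size, (1, n_context))
    print(ids.shape)                      # torch.Size([1, 128])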
@@ -143,11 +141,10 @@ if __name__ == '__main__':
     python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
 
     - Run a speedtest to benchmark the quantized model:
-    python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt \
-        --n_generate 128 --max_seq_len 256
+    python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt --n_generate 128 --n_context 256
 
     - Run a speedtest to benchmark the unquantized FP16 model:
-    python -m awq.entry --entry_type speed --model_path lmsys/vicuna-7b-v1.5 --n_generate 128 --max_seq_len 256
+    python -m awq.entry --entry_type speed --model_path lmsys/vicuna-7b-v1.5 --n_generate 128 --n_context 256
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval|speed)')
@@ -166,7 +163,7 @@ if __name__ == '__main__':
     parser.add_argument('--task_batch_size', type=int, default=1)
     parser.add_argument('--task_n_shot', type=int, default=0)
     parser.add_argument('--n_generate', type=int, default=128)
-    parser.add_argument('--max_seq_len', type=int, default=256)
+    parser.add_argument('--n_context', type=int, default=256)
     args = parser.parse_args()
 
     quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }
@@ -179,6 +176,6 @@ if __name__ == '__main__':
         run_eval(args.model_path, args.quant_file, args.device,
                     args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
     elif args.entry_type == 'speed':
-        run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.max_seq_len)
+        run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context)
     else:
         raise Exception('--entry_type must be one of (search|quant|eval|speed)')
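
For completeness, the same benchmark can be driven from Python rather than the CLI. A hedged sketch, reusing the model path and quant file names from the docstring examples above and assuming a CUDA device string for `device`:

    # Hypothetical direct call; argument values mirror the CLI examples in the
    # docstring and the argparse defaults after this commit.
    from awq.entry import run_speed

    run_speed(
        model_path="vicuna-7b-v1.5-awq",
        quant_file="awq_model_w4_g128.pt",
        device="cuda:0",          # assumed format of the --device flag
        n_generate=128,
        n_context=256,            # reverted name; previously max_seq_len
    )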