OpenDAS / AutoAWQ · Commits

Commit 883e47d5
Authored Sep 05, 2023 by Casper Hansen
Implement batch size
Parent: abdc726c

Showing 1 changed file (awq/entry.py) with 23 additions and 11 deletions (+23 -11).

awq/entry.py
@@ -74,29 +74,35 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
     print(evaluator.make_table(results))

 @torch.inference_mode()
-def run_speed(model_path, quant_file, device, n_generate=128, n_context=256):
+def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, batch_size=1, disable_fused_layers=False):
     def _timer(func):
         start = time.time()
         out = func()
         return out, time.time() - start

-    def _generate(model, model_out, n_generate):
+    def _generate(model, model_out, n_generate, batch_size):
         past_key_values = model_out.past_key_values

         for i in range(n_generate):
-            logits = model_out.logits[0, -1, :]
-            probs = torch.softmax(logits, dim=-1)
-            token = torch.multinomial(probs, num_samples=1)
-            token = torch.as_tensor([token], device=device).unsqueeze(0)
-            model_out = model(token, use_cache=True, past_key_values=past_key_values)
+            logits = model_out.logits[:, -1, :]
+            new_tokens = []
+
+            for batch_index in range(batch_size):
+                probs = torch.softmax(logits[batch_index], dim=-1)
+                token = torch.multinomial(probs, num_samples=1)
+                new_tokens.append(token)
+
+            tokens = torch.as_tensor(new_tokens, device=device).unsqueeze(-1)
+            model_out = model(tokens, use_cache=True, past_key_values=past_key_values)

     def _warmup(device: str):
         warm_up = torch.randn((4096, 4096)).to(device)
         torch.mm(warm_up, warm_up)

     if quant_file:
-        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True))
+        fuse_layers = False if disable_fused_layers else True
+        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=fuse_layers))
     else:
         model, load_time = _timer(lambda: AutoAWQForCausalLM.from_pretrained(model_path))
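The heart of this hunk is the sampling loop in _generate: the old code sliced out a single sequence with logits[0, -1, :], while the new code keeps the batch dimension (logits[:, -1, :]) and samples one token per row. Worth noting: torch.multinomial also accepts a 2-D probability tensor and samples each row independently, so the per-row loop could be collapsed into a single call. A minimal, self-contained sketch of that vectorized form (the batch_size and vocab_size values are illustrative, not from this commit):

    import torch

    batch_size, vocab_size = 4, 32000
    logits = torch.randn(batch_size, vocab_size)      # stands in for model_out.logits[:, -1, :]

    probs = torch.softmax(logits, dim=-1)             # (batch_size, vocab_size), rows sum to 1
    tokens = torch.multinomial(probs, num_samples=1)  # (batch_size, 1): one sampled token per row
    assert tokens.shape == (batch_size, 1)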
@@ -105,13 +111,13 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256):
     # Generate random inputs
     n_context = n_context - n_generate
-    ids = torch.randint(0, tokenizer.vocab_size, (1, n_context)).cuda()
+    ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context)).cuda()

     # Context stage
     model_out, context_time = _timer(lambda: model(ids, use_cache=True))

     # Generation stage
-    _, generation_time = _timer(lambda: _generate(model, model_out, n_generate))
+    _, generation_time = _timer(lambda: _generate(model, model_out, n_generate, batch_size))

     # Prints
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
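With a batch dimension on ids, the context (prefill) stage and every decode step now process batch_size sequences at once, so any throughput derived from generation_time should count batch_size * n_generate tokens. An illustrative calculation (the variable names mirror this diff, but the print section is truncated here, so the exact reporting code is an assumption):

    # Illustrative values, not measurements from this commit.
    batch_size, n_generate = 8, 128
    generation_time = 4.0  # seconds, as returned by _timer

    # Each decode step emits one token per sequence in the batch.
    tokens_per_second = (batch_size * n_generate) / generation_time
    print(f"{tokens_per_second:.1f} tokens/s")  # 256.0 tokens/s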
@@ -164,6 +170,9 @@ if __name__ == '__main__':
     parser.add_argument('--task_n_shot', type=int, default=0)
     parser.add_argument('--n_generate', type=int, default=128)
     parser.add_argument('--n_context', type=int, default=256)
+    parser.add_argument('--batch_size', type=int, default=1)
+    parser.add_argument("--disable_fused_layers", default=False, action='store_true',
+                        help="Pass '--disable_fused_layers' to disable fused layers")
     args = parser.parse_args()

     quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }
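The two new flags are plain argparse additions: --batch_size takes an integer (default 1) and --disable_fused_layers is a store_true switch. A self-contained sketch of how they parse, reproducing just these two add_argument calls (the rest of the parser is omitted):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument("--disable_fused_layers", default=False, action='store_true',
                        help="Pass '--disable_fused_layers' to disable fused layers")

    # e.g. `... --batch_size 8 --disable_fused_layers` on the command line:
    args = parser.parse_args(['--batch_size', '8', '--disable_fused_layers'])
    assert args.batch_size == 8 and args.disable_fused_layers is True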
@@ -176,6 +185,9 @@ if __name__ == '__main__':
         run_eval(args.model_path, args.quant_file, args.device,
                  args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
     elif args.entry_type == 'speed':
-        run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context)
+        if args.batch_size > 1 and not args.disable_fused_layers:
+            raise Exception('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
+
+        run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
     else:
         raise Exception('--entry_type must be one of (search|quant|eval|speed)')
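The guard encodes the constraint stated in the error message: the fused layers only support batch_size=1, so batched speed runs must opt out of fusion (at a performance cost). A minimal standalone sketch of the same check (the helper name is hypothetical, not part of the commit):

    def check_batch_config(batch_size: int, disable_fused_layers: bool) -> None:
        # Hypothetical helper mirroring the guard above: fused layers are
        # incompatible with batch_size > 1, so batched runs must disable them.
        if batch_size > 1 and not disable_fused_layers:
            raise Exception('Fused layers only support batch_size=1. '
                            'Pass --disable_fused_layers to run batch_size>1 (much slower).')

    check_batch_config(1, False)  # fused layers, single sequence: fine
    check_batch_config(8, True)   # batched, fusion disabled: fine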