OpenDAS / AutoAWQ

Commit a4626828, authored Sep 06, 2023 by Casper Hansen
Parent: e71181bd

Switch to model.generate()

Showing 1 changed file with 21 additions and 20 deletions.

awq/entry.py  (+21, -20)
@@ -3,11 +3,11 @@ import time
 import torch
 import argparse
 from lm_eval import evaluator
-from transformers import AutoTokenizer
 from awq import AutoAWQForCausalLM
 from awq.quantize.auto_clip import apply_clip
 from awq.quantize.auto_scale import apply_scale
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
+from transformers import AutoTokenizer, GenerationConfig

 def load_search_result_into_memory(model, search_path):
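The import change supports the new benchmark below: GenerationConfig collects the decoding settings in one object that is handed to model.generate(). A minimal sketch of that pattern on its own, assuming a generic Hugging Face causal LM (the checkpoint name and prompt are placeholders, not part of this commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    # Placeholder checkpoint, purely for illustration.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    ids = tokenizer("Hello", return_tensors="pt").input_ids
    # All decoding options travel together in one GenerationConfig object.
    config = GenerationConfig(max_new_tokens=8, use_cache=True)
    out = model.generate(ids, generation_config=config)
    print(tokenizer.decode(out[0]))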
@@ -80,22 +80,6 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
         out = func()
         return out, time.time() - start

-    def _generate(model, model_out, n_generate, batch_size):
-        past_key_values = model_out.past_key_values
-
-        for i in range(n_generate):
-            logits = model_out.logits[:, -1, :]
-            new_tokens = []
-
-            for batch_index in range(batch_size):
-                probs = torch.softmax(logits[batch_index], dim=-1)
-                token = torch.multinomial(probs, num_samples=1)
-                new_tokens.append(token)
-
-            tokens = torch.as_tensor(new_tokens, device=device).unsqueeze(-1)
-            model_out = model(tokens, use_cache=True, past_key_values=past_key_values)
-
-
     def _warmup(device: str):
         warm_up = torch.randn((4096, 4096)).to(device)
         torch.mm(warm_up, warm_up)
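The deleted _generate helper was a hand-rolled decode loop: starting from the prefill output, it sampled one token per batch row from the last-position logits and fed it back together with the cached past_key_values. model.generate() runs the same loop (plus stopping logic) internally. A self-contained sketch of that manual pattern, assuming a Hugging Face-style causal LM whose forward pass returns logits and past_key_values:

    import torch

    def manual_sample(model, input_ids, n_generate):
        # Prefill once over the full prompt, keeping the KV cache.
        out = model(input_ids, use_cache=True)
        past_key_values = out.past_key_values
        generated = []
        for _ in range(n_generate):
            # Sample one token per batch row from the final position.
            probs = torch.softmax(out.logits[:, -1, :], dim=-1)
            tokens = torch.multinomial(probs, num_samples=1)
            generated.append(tokens)
            # Feed only the new tokens back, reusing the cache.
            out = model(tokens, use_cache=True, past_key_values=past_key_values)
            past_key_values = out.past_key_values
        return torch.cat(generated, dim=1)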
@@ -114,10 +98,27 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
     ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context)).cuda()

     # Context stage
-    model_out, context_time = _timer(lambda: model(ids, use_cache=True))
+    _, context_time = _timer(lambda: model.generate(
+        ids,
+        generation_config=GenerationConfig(
+            max_new_tokens=0,
+            min_new_tokens=0,
+            use_cache=True
+        )
+    ))

     # Generation stage
-    _, generation_time = _timer(lambda: _generate(model, model_out, n_generate, batch_size))
+    _, generation_time = _timer(lambda: model.generate(
+        ids,
+        generation_config=GenerationConfig(
+            max_new_tokens=n_context,
+            min_new_tokens=n_context,
+            forced_eos_token_id=-100,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=-100,
+            use_cache=True
+        )
+    ))

     # Prints
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
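Both stages are now timed through model.generate(). The context stage requests zero new tokens so the timed call approximates just the prefill over ids; the generation stage sets min_new_tokens equal to max_new_tokens and points eos_token_id at an id that never appears, so early stopping cannot shorten the run and the elapsed time corresponds to a known number of tokens. A condensed sketch of that forced-length timing pattern (the helper name and arguments are illustrative, not the script's API):

    import time
    from transformers import GenerationConfig

    def timed_generate(model, ids, n_tokens, pad_token_id):
        # min == max plus an unreachable EOS id force exactly n_tokens new tokens.
        config = GenerationConfig(
            max_new_tokens=n_tokens,
            min_new_tokens=n_tokens,
            eos_token_id=-100,
            forced_eos_token_id=-100,
            pad_token_id=pad_token_id,
            use_cache=True,
        )
        start = time.time()
        model.generate(ids, generation_config=config)
        return time.time() - start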
@@ -126,7 +127,7 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
     inference_tokens_per_second = n_generate / generation_time * batch_size
     inference_ms_per_token = (generation_time * 1000) / n_generate / batch_size

-    print(f"[======] Model summary: {model_path} [======]")
+    print(f"[=] Model summary: {model_path} [=]")
     print(f"[*] Load time: {load_time:.2f} seconds")
     print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
     print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
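The summary formulas themselves are unchanged: throughput is the number of generated tokens across the batch divided by wall-clock time, and ms/token is its reciprocal expressed in milliseconds. A quick numeric check with hypothetical measurements:

    n_generate, batch_size, generation_time = 128, 1, 4.0  # hypothetical values
    tokens_per_second = n_generate / generation_time * batch_size       # 32.00
    ms_per_token = (generation_time * 1000) / n_generate / batch_size   # 31.25
    print(f"{tokens_per_second:.2f} tokens/second ({ms_per_token:.2f} ms/token)")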