Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
083b80ea
Commit
083b80ea
authored
Jan 16, 2025
by
zhuwenwen
Browse files
增加w8a8相关修改
parent
09428eec
Changes
42
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
6 deletions
+37
-6
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+32
-5
vllm/utils.py
vllm/utils.py
+5
-1
No files found.
vllm/model_executor/models/qwen2.py
View file @
083b80ea
...
...
@@ -60,6 +60,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.utils
import
W8a8GetCacheJSON
logger
=
init_logger
(
__name__
)
...
...
@@ -329,12 +330,13 @@ class Qwen2Model(nn.Module):
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'
0
'
))
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'
1
'
))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return whatever ``self.embed_tokens`` produces for ``input_ids``.

    Args:
        input_ids: Tensor handed unchanged to ``self.embed_tokens``.

    Returns:
        The value returned by ``self.embed_tokens(input_ids)``.
    """
    embeddings = self.embed_tokens(input_ids)
    return embeddings
...
...
@@ -510,15 +512,40 @@ class Qwen2Model(nn.Module):
"mlp.down_proj.weight"
,
]
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
k
=
weight_data
.
shape
[
0
]
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
k
,
-
1
)
n
=
weight_data
.
shape
[
0
]
# rocblas and cutlass currently both require the weight to be preprocessed, but triton does not
if
self
.
w8a8_strategy
!=
1
:
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
# Below: record the k and n values that occur in the model
elif
len
(
weight_shapes
)
<
4
:
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
all_json
.
update
(
configs_dict
)
if
self
.
w8a8_strategy
==
1
:
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
# Run one warmup pass for every config that was found
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
best_config
=
value
)
return
loaded_params
...
...
vllm/utils.py
View file @
083b80ea
...
...
@@ -1466,7 +1466,11 @@ class W8a8GetCacheJSON:
def _initialize(self):
    """Resolve the W8A8 Triton JSON config directory and reset the cache list.

    The directory is taken from the ``TRITON_JSON_DIR`` environment variable
    when it is set. Otherwise the default is ``../lmslim/configs/w8a8``
    relative to this file's directory if that path exists, falling back to
    ``model_executor/layers/quantization/configs/w8a8`` under this file's
    directory.

    Sets:
        self.triton_json_dir: the resolved directory path (str).
        self.triton_json_dict: a fresh empty list.
    """
    current_folder_path = os.path.dirname(os.path.abspath(__file__))

    # A stale assignment of self.triton_json_dir to the bundled path used to
    # sit here; it was always overwritten below, so it has been dropped.
    json_folder_path = current_folder_path + '/../lmslim/configs/w8a8'
    if not os.path.exists(json_folder_path):
        # lmslim configs are absent: use the configs next to this file.
        json_folder_path = (
            current_folder_path +
            '/model_executor/layers/quantization/configs/w8a8')

    # The environment variable, when set, overrides either default.
    self.triton_json_dir = os.getenv('TRITON_JSON_DIR', json_folder_path)
    self.triton_json_dict = []
def
getspec_config
(
self
,
configs_dict
,
M
,
N
,
K
):
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment