Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1cb37dab
"vscode:/vscode.git/clone" did not exist on "9556af87d5d5a38128db0d09eeb7f2fe16f16589"
Commit
1cb37dab
authored
Jun 05, 2025
by
yangql
Browse files
新增qwen3 fusemoe优化支持
parent
1a4e4cad
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
215 additions
and
35 deletions
+215
-35
vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=K100_AI,dtype=int4_w4a16.json
...figs/E=128,N=96,device_name=K100_AI,dtype=int4_w4a16.json
+182
-0
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+33
-3
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+0
-32
No files found.
vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=K100_AI,dtype=int4_w4a16.json
0 → 100644
View file @
1cb37dab
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
0
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
0
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"256"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
0
},
"6144"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
}
}
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
1cb37dab
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
import
torch
import
os
from
vllm.distributed
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.distributed
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.model_executor.layers.fused_moe.layer
import
(
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
...
@@ -175,7 +175,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -175,7 +175,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
def __init__(self, quant_config: MoeWNA16Config):
    """Store the quantization config and read the ``AWQ_MOE_SZ`` switch.

    Args:
        quant_config: Quantization settings kept on the instance as
            ``self.quant_config``.
    """
    # The flag is true only when the AWQ_MOE_SZ environment variable is
    # exactly the string '1'; unset or any other value leaves it false.
    sz_switch = os.environ.get('AWQ_MOE_SZ')
    self.use_w4a16_moe_sz = sz_switch == '1'
    self.quant_config = quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
...
@@ -278,6 +278,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
...
@@ -278,6 +278,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer
.
register_parameter
(
key
,
param
)
layer
.
register_parameter
(
key
,
param
)
set_weight_attrs
(
param
,
extra_weight_attrs
)
set_weight_attrs
(
param
,
extra_weight_attrs
)
def
restore_qzeros_tensor
(
self
,
qzeros
,
qscales
):
low_bits
=
qzeros
&
0x0F
high_bits
=
qzeros
>>
4
zeors_tensor
=
torch
.
stack
([
low_bits
,
high_bits
],
dim
=
2
).
view
(
qzeros
.
shape
[
0
],
-
1
,
qzeros
.
shape
[
-
1
])
zeors_int16
=
zeors_tensor
.
to
(
torch
.
int16
)
assert
zeors_int16
.
shape
==
qscales
.
shape
uint16_tensor1
=
zeors_int16
.
view
(
torch
.
uint16
)
uint16_tensor2
=
qscales
.
view
(
torch
.
uint16
)
uint32_tensor1
=
uint16_tensor1
.
to
(
torch
.
int32
)
<<
16
uint32_tensor2
=
uint16_tensor2
.
to
(
torch
.
int32
)
result_tensor
=
uint32_tensor1
+
uint32_tensor2
result_tensor
=
result_tensor
.
view
(
torch
.
uint32
)
result_tensor
=
result_tensor
.
transpose
(
1
,
2
).
contiguous
()
return
result_tensor
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Replace the layer's scales with fused scale/zero-point tensors.

    Does nothing unless ``self.use_w4a16_moe_sz`` is truthy. When it is,
    ``self.restore_qzeros_tensor`` is called for the ``w13`` and ``w2``
    (qzeros, scales) pairs, the results are stored as non-trainable
    parameters in ``layer.w13_scales`` / ``layer.w2_scales``, and the
    separate qzeros attributes are set to ``None``.

    Args:
        layer: Module exposing ``w13_qzeros``, ``w13_scales``,
            ``w2_qzeros`` and ``w2_scales``.
    """
    if self.use_w4a16_moe_sz:
        # Build both fused tensors before touching the layer so the inputs
        # are still intact when the second one is computed.
        fused_w13 = self.restore_qzeros_tensor(
            layer.w13_qzeros,
            layer.w13_scales,
        )
        fused_w2 = self.restore_qzeros_tensor(
            layer.w2_qzeros,
            layer.w2_scales,
        )

        layer.w13_scales = torch.nn.Parameter(fused_w13, requires_grad=False)
        layer.w2_scales = torch.nn.Parameter(fused_w2, requires_grad=False)

        # The zero points now live inside the fused tensors.
        layer.w13_qzeros = None
        layer.w2_qzeros = None

        # NOTE(review): indentation was lost in the reviewed text; this call
        # is assumed to belong to the fused branch — confirm.
        torch.cuda.empty_cache()
def
apply
(
def
apply
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
1cb37dab
...
@@ -896,38 +896,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -896,38 +896,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
if
hasattr
(
self
.
config
,
"quantization_config"
)
and
self
.
config
.
quantization_config
[
"quant_method"
]
==
"awq"
and
not
envs
.
VLLM_USE_TRITON_AWQ
:
lay_key_words
=
[
"self_attn.q_a_proj.qweight"
,
"self_attn.q_b_proj.qweight"
,
"self_attn.kv_b_proj.qweight"
,
"self_attn.kv_a_proj_with_mqa.qweight"
,
"self_attn.o_proj.qweight"
,
"mlp.gate_up_proj.qweight"
,
"mlp.down_proj.qweight"
,
"mlp.shared_experts.gate_up_proj.qweight"
,
"mlp.shared_experts.down_proj.qweight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
# moe_gather_sz
moe_key_words
=
[
"mlp.experts.w13_qweight"
,
"mlp.experts.w2_qweight"
]
moe_combined_words
=
"|"
.
join
(
moe_key_words
)
for
layername
in
loaded_params
:
weight
=
params_dict
[
layername
]
matches
=
re
.
findall
(
combined_words
,
layername
)
if
self
.
use_w4a16_moe_sz
:
matches_moe
=
re
.
findall
(
moe_combined_words
,
layername
)
# sz.shape == s.shape.T
if
matches_moe
:
qzeros
=
params_dict
[
layername
.
replace
(
"qweight"
,
"qzeros"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
sz_tensor
=
self
.
restore_qzeros_tensor
(
qzeros
,
scales
)
scales
.
data
=
sz_tensor
if
hasattr
(
self
.
config
,
"quantization_config"
)
and
self
.
config
.
quantization_config
[
"quant_method"
]
==
"blockwise_int8"
:
if
hasattr
(
self
.
config
,
"quantization_config"
)
and
self
.
config
.
quantization_config
[
"quant_method"
]
==
"blockwise_int8"
:
lay_key_words
=
[
lay_key_words
=
[
"self_attn.q_a_proj.weight"
,
"self_attn.q_a_proj.weight"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment