Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99ffef47
Commit
99ffef47
authored
Jun 12, 2025
by
zhuwenwen
Browse files
修复qwen3-moe的awq配置导致fp16加载错误,修复dpsk-moe手写算子首字耗时增加问题
parent
550a1e5e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
221 additions
and
192 deletions
+221
-192
vllm/model_executor/layers/quantization/moe_wna16.py
vllm/model_executor/layers/quantization/moe_wna16.py
+2
-2
vllm/model_executor/layers/quantization/utils/fused_moe_cuda.py
...odel_executor/layers/quantization/utils/fused_moe_cuda.py
+211
-182
vllm/model_executor/models/qwen3_moe.py
vllm/model_executor/models/qwen3_moe.py
+8
-8
No files found.
vllm/model_executor/layers/quantization/moe_wna16.py
View file @
99ffef47
...
...
@@ -354,7 +354,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
has_zp
=
self
.
quant_config
.
has_zp
if
self
.
use_w4a16_cuda
:
m
=
topk_ids
.
shape
[
0
]
if
m
<=
64
:
if
m
<=
512
:
return
fused_experts_cuda
(
x
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
...
...
vllm/model_executor/layers/quantization/utils/fused_moe_cuda.py
View file @
99ffef47
...
...
@@ -21,136 +21,194 @@ from grouped_gemm import moe_gemm_w4a16
from
grouped_gemm.ops
import
permute
as
permute_topK
,
unpermute
as
unpermute_topK
import
torch.nn.functional
as
F
logger
=
init_logger
(
__name__
)
device_name
=
current_platform
.
get_device_name
()
def
config_cuda
(
M
):
bw_gemm1_mode_dict
=
{
1
:
83
,
2
:
77
,
3
:
32
,
4
:
38
,
5
:
38
,
6
:
87
,
7
:
82
,
8
:
42
,
9
:
83
,
10
:
42
,
11
:
42
,
12
:
87
,
13
:
42
,
14
:
38
,
15
:
42
,
16
:
42
,
17
:
42
,
18
:
87
,
19
:
87
,
20
:
83
,
21
:
83
,
22
:
83
,
23
:
83
,
24
:
27
,
25
:
42
,
26
:
83
,
27
:
38
,
28
:
42
,
29
:
42
,
30
:
38
,
31
:
42
,
32
:
38
k100ai_gemm1_m_to_mode_dict
=
{
1
:
'M16N16K256NN1NW8B240'
,
2
:
'M16N16K256NN1NW16B360'
,
3
:
'M16N16K256NN1NW16B360'
,
4
:
'M16N16K256NN1NW16B360'
,
5
:
'M16N16K256NN1NW16B360'
,
6
:
'M16N16K256NN1NW16B360'
,
7
:
'M16N16K256NN1NW16B360'
,
8
:
'M16N16K256NN1NW16B240'
,
9
:
'M16N16K256NN1NW16B360'
,
10
:
'M16N16K256NN1NW16B360'
,
11
:
'M16N16K256NN1NW16B360'
,
12
:
'M16N16K256NN1NW16B360'
,
13
:
'M16N16K256NN1NW16B360'
,
14
:
'M16N16K256NN1NW16B360'
,
15
:
'M16N16K256NN1NW16B360'
,
16
:
'M16N16K256NN1NW16B240'
,
17
:
'M16N16K256NN1NW16B360'
,
18
:
'M16N16K256NN1NW16B360'
,
19
:
'M16N16K256NN1NW16B360'
,
20
:
'M16N16K256NN1NW16B360'
,
21
:
'M16N16K256NN1NW16B360'
,
22
:
'M16N16K256NN1NW16B360'
,
23
:
'M16N16K256NN1NW16B360'
,
24
:
'M16N16K256NN1NW16B360'
,
25
:
'M16N16K256NN1NW16B360'
,
26
:
'M16N16K256NN1NW16B360'
,
27
:
'M16N16K256NN1NW16B240'
,
28
:
'M16N16K256NN1NW16B360'
,
29
:
'M16N16K256NN1NW16B360'
,
30
:
'M16N16K256NN1NW16B360'
,
31
:
'M16N16K256NN1NW16B240'
,
32
:
'M16N16K256NN1NW16B360'
,
64
:
'M16N16K256NN1NW16B360'
,
128
:
'M16N16K256NN1NW16B240'
,
256
:
'M16N16K256NN1NW16B360'
,
512
:
'M16N16K128NN1NW8B120'
,
768
:
'M16N32K128NN1NW16B100'
,
1024
:
'M16N32K128NN1NW16B120'
,
}
bw_gemm2_mode_dict
=
{
1
:
23
,
2
:
88
,
3
:
74
,
4
:
39
,
5
:
43
,
6
:
88
,
7
:
88
,
8
:
89
,
9
:
73
,
10
:
88
,
11
:
88
,
12
:
88
,
13
:
88
,
14
:
88
,
15
:
88
,
16
:
88
,
17
:
88
,
18
:
43
,
19
:
88
,
20
:
43
,
21
:
43
,
22
:
43
,
23
:
88
,
24
:
88
,
25
:
88
,
26
:
88
,
27
:
88
,
28
:
88
,
29
:
88
,
30
:
43
,
31
:
88
,
32
:
88
k100ai_gemm2_m_to_mode_dict
=
{
1
:
'M16N32K256NN8NW1B240'
,
2
:
'M16N32K256NN8NW1B360'
,
3
:
'M16N32K256NN8NW1B360'
,
4
:
'M16N32K256NN4NW1B360'
,
5
:
'M16N32K256NN4NW1B360'
,
6
:
'M16N32K256NN4NW1B360'
,
7
:
'M16N32K256NN4NW1B360'
,
8
:
'M16N32K256NN8NW1B360'
,
9
:
'M16N32K256NN8NW1B240'
,
10
:
'M16N32K256NN8NW1B240'
,
11
:
'M16N32K256NN8NW1B240'
,
12
:
'M16N32K256NN8NW1B240'
,
13
:
'M16N32K256NN4NW1B360'
,
14
:
'M16N32K256NN16NW1B360'
,
15
:
'M16N32K256NN16NW1B360'
,
16
:
'M16N32K256NN16NW1B360'
,
17
:
'M16N32K256NN8NW1B240'
,
18
:
'M16N32K256NN8NW1B240'
,
19
:
'M16N32K256NN16NW1B360'
,
20
:
'M16N32K256NN16NW1B360'
,
21
:
'M16N32K256NN16NW1B360'
,
22
:
'M16N32K256NN16NW1B360'
,
23
:
'M16N32K256NN16NW1B360'
,
24
:
'M16N32K256NN16NW1B240'
,
25
:
'M16N32K256NN16NW1B360'
,
26
:
'M16N32K256NN16NW1B360'
,
27
:
'M16N32K256NN16NW1B360'
,
28
:
'M16N32K256NN16NW1B360'
,
29
:
'M16N64K256NN4NW1B240'
,
30
:
'M16N32K256NN16NW1B360'
,
31
:
'M16N32K256NN16NW1B360'
,
32
:
'M16N32K256NN16NW1B240'
,
64
:
'M16N32K256NN16NW1B360'
,
128
:
'M16N64K256NN4NW1B240'
,
256
:
'M16N32K256NN16NW1B360'
,
512
:
'M16N64K256NN8NW1B120'
,
768
:
'M16N64K256NN16NW1B360'
,
1024
:
'M16N64K256NN16NW1B360'
,
}
k100ai_gemm1_mode_dict
=
{
1
:
79
,
2
:
34
,
3
:
34
,
4
:
34
,
6
:
34
,
8
:
34
,
16
:
34
,
24
:
34
,
32
:
34
,
bw_gemm1_m_to_mode_dict
=
{
1
:
'M16N16K256NN1NW8B360'
,
2
:
'M16N16K256NN1NW4B360'
,
3
:
'M16N32K256NN1NW8B240'
,
4
:
'M16N32K256NN1NW4B360'
,
5
:
'M16N64K256NN1NW4B240'
,
6
:
'M16N32K256NN1NW8B240'
,
7
:
'M16N32K256NN1NW8B360'
,
8
:
'M16N64K256NN1NW4B360'
,
9
:
'M16N64K256NN1NW4B240'
,
10
:
'M16N32K256NN1NW8B240'
,
11
:
'M16N64K256NN1NW4B240'
,
12
:
'M16N64K256NN1NW4B360'
,
13
:
'M16N32K256NN1NW8B240'
,
14
:
'M16N32K256NN1NW8B240'
,
15
:
'M16N32K256NN1NW8B240'
,
16
:
'M16N64K256NN1NW4B360'
,
17
:
'M16N32K256NN1NW8B240'
,
18
:
'M16N64K256NN1NW4B240'
,
19
:
'M16N32K256NN1NW8B240'
,
20
:
'M16N32K256NN1NW8B240'
,
21
:
'M16N32K256NN1NW8B240'
,
22
:
'M16N32K256NN1NW8B240'
,
23
:
'M16N32K256NN1NW8B240'
,
24
:
'M16N64K256NN1NW4B240'
,
25
:
'M16N32K256NN1NW8B240'
,
26
:
'M16N32K256NN1NW8B240'
,
27
:
'M16N32K256NN1NW8B240'
,
28
:
'M16N32K256NN1NW8B240'
,
29
:
'M16N64K256NN1NW4B240'
,
30
:
'M16N64K256NN1NW4B240'
,
31
:
'M16N32K256NN1NW8B240'
,
32
:
'M16N64K256NN1NW4B240'
,
64
:
'M16N32K256NN1NW8B240'
,
128
:
'M16N64K256NN1NW4B240'
,
256
:
'M16N64K256NN1NW4B240'
,
512
:
'M16N64K256NN1NW4B240'
,
768
:
'M16N64K256NN1NW4B240'
,
1024
:
'M16N64K256NN1NW4B240'
,
}
k100ai_gemm2_mode_dict
=
{
1
:
64
,
2
:
33
,
3
:
33
,
4
:
37
,
5
:
37
,
6
:
33
,
7
:
33
,
8
:
37
,
9
:
37
,
10
:
37
,
11
:
37
,
12
:
37
,
13
:
37
,
14
:
38
,
15
:
38
,
16
:
72
,
17
:
72
,
18
:
72
,
19
:
72
,
20
:
72
,
21
:
72
,
22
:
72
,
23
:
72
,
24
:
39
,
25
:
39
,
26
:
39
,
27
:
39
,
28
:
39
,
29
:
39
,
30
:
39
,
31
:
39
,
32
:
39
,
bw_gemm2_m_to_mode_dict
=
{
1
:
'M16N32K128NN8NW1B240'
,
2
:
'M16N64K256NN8NW1B240'
,
3
:
'M16N64K256NN4NW1B360'
,
4
:
'M16N64K256NN16NW1B240'
,
5
:
'M16N64K256NN8NW1B240'
,
6
:
'M16N64K256NN8NW1B240'
,
7
:
'M16N64K256NN16NW1B240'
,
8
:
'M16N64K256NN8NW1B240'
,
9
:
'M16N64K256NN16NW1B360'
,
10
:
'M16N64K256NN8NW1B240'
,
11
:
'M16N64K256NN16NW1B360'
,
12
:
'M16N64K256NN8NW1B240'
,
13
:
'M16N64K256NN16NW1B240'
,
14
:
'M16N64K256NN16NW1B360'
,
15
:
'M16N64K256NN16NW1B240'
,
16
:
'M16N64K256NN16NW1B240'
,
17
:
'M16N64K256NN8NW1B240'
,
18
:
'M16N64K256NN8NW1B240'
,
19
:
'M16N64K256NN16NW1B240'
,
20
:
'M16N64K256NN8NW1B240'
,
21
:
'M16N64K256NN16NW1B240'
,
22
:
'M16N64K256NN16NW1B360'
,
23
:
'M16N64K256NN16NW1B360'
,
24
:
'M16N64K256NN16NW1B240'
,
25
:
'M16N64K256NN8NW1B240'
,
26
:
'M16N64K256NN16NW1B240'
,
27
:
'M16N64K256NN16NW1B240'
,
28
:
'M16N64K256NN16NW1B240'
,
29
:
'M16N64K256NN16NW1B240'
,
30
:
'M16N64K256NN16NW1B240'
,
31
:
'M16N64K256NN8NW1B240'
,
32
:
'M16N64K256NN16NW1B240'
,
64
:
'M16N64K256NN16NW1B240'
,
128
:
'M16N64K256NN16NW1B240'
,
256
:
'M16N64K256NN16NW1B240'
,
512
:
'M16N64K256NN16NW1B240'
,
768
:
'M16N64K256NN16NW1B240'
,
1024
:
'M16N64K256NN16NW1B240'
,
}
device_name
=
device_name
=
current_platform
.
get_device_name
()
if
"BW"
in
device_name
:
gemm1_mode_dict
=
bw_gemm1_mode_dict
gemm2_mode_dict
=
bw_gemm2_mode_dict
reference_points
=
[
32
,
64
,
128
,
256
,
512
,
1024
]
NearestM
=
-
1
if
M
<=
32
:
NearestM
=
M
else
:
gemm1_mode_dict
=
k100ai_gemm1_mode_dict
gemm2_mode_dict
=
k100ai_gemm2_mode_dict
NearestM
=
min
(
reference_points
,
key
=
lambda
x
:
abs
(
x
-
M
))
mode_1
=
gemm1_mode_dict
.
get
(
M
,
gemm1_mode_dict
[
32
])
mode_2
=
gemm2_mode_dict
.
get
(
M
,
gemm2_mode_dict
[
32
])
if
device_name
==
"K100_AI"
:
mode_1
=
k100ai_gemm1_m_to_mode_dict
.
get
(
M
,
k100ai_gemm1_m_to_mode_dict
[
NearestM
])
mode_2
=
k100ai_gemm2_m_to_mode_dict
.
get
(
M
,
k100ai_gemm2_m_to_mode_dict
[
NearestM
])
else
:
mode_1
=
bw_gemm1_m_to_mode_dict
.
get
(
M
,
k100ai_gemm1_m_to_mode_dict
[
NearestM
])
mode_2
=
bw_gemm2_m_to_mode_dict
.
get
(
M
,
k100ai_gemm2_m_to_mode_dict
[
NearestM
])
return
mode_1
,
mode_2
def
fused_experts_cuda
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
@@ -218,18 +276,19 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor,
E
,
N
,
_
=
w1
.
shape
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
32768
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
M
=
num_tokens
topk
=
topk_ids
.
shape
[
1
]
mode_1
,
mode_2
=
config_cuda
(
M
)
# config = get_config_func(M)
intermediate_cache1
=
torch
.
empty
((
M
,
topk
_ids
.
shape
[
1
]
,
N
),
intermediate_cache1
=
torch
.
empty
((
M
,
topk
,
N
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
intermediate_cache2
=
torch
.
empty
((
M
*
topk
_ids
.
shape
[
1
]
,
N
//
2
),
intermediate_cache2
=
torch
.
empty
((
M
*
topk
,
N
//
2
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
intermediate_cache3
=
torch
.
empty
((
M
,
topk
_ids
.
shape
[
1
]
,
w2
.
shape
[
1
]),
intermediate_cache3
=
torch
.
empty
((
M
,
topk
,
w2
.
shape
[
1
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
...
...
@@ -247,63 +306,33 @@ def fused_experts_impl_cuda(hidden_states: torch.Tensor,
else
:
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
for
chunk
in
range
((
num_tokens
//
CHUNK_SIZE
)
+
1
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
))
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
if
tokens_in_chunk
==
0
:
break
if
tokens_in_chunk
<
CHUNK_SIZE
and
chunk
>
0
:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1
=
intermediate_cache1
[:
tokens_in_chunk
]
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
*
topk_ids
.
shape
[
1
]]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
]
# config = get_config_func(tokens_in_chunk)
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
16
,
E
,
expert_map
,
curr_hidden_states
.
shape
[
0
]))
mode_1
,
mode_2
=
config_cuda
(
M
)
expert_ids
=
expert_ids
[:
num_tokens_post_padded
//
16
]
moe_gemm_w4a16
.
gemm1_w4a16
(
sorted_token_ids
.
to
(
torch
.
uint16
),
# sorted_token_ids.to(torch.uint16)
curr_hidden_states
,
# hidden_states
moe_align_block_size
(
topk_ids
,
16
,
E
,
expert_map
,
hidden_states
.
shape
[
0
]))
moe_gemm_w4a16
.
gemm1_w4a16
(
sorted_token_ids
,
# sorted_token_ids.to(torch.uint16)
hidden_states
,
# hidden_states
w1
,
# w1
intermediate_cache1
,
# gemm1_out
num_tokens_post_padded
,
# 实际专家数
expert_ids
,
# expert_id_vec
w1_scale
,
# scale_zero
64
,
# group_size
topk
=
topk
_ids
.
shape
[
1
]
,
# topk
topk
=
topk
,
# topk
mode
=
mode_1
)
# mode=gemm1_mode
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
moe_gemm_w4a16
.
gemm2_w4a16
(
sorted_token_ids
.
to
(
torch
.
uint16
),
# sorted_token_ids.to(torch.uint16)
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
# return intermediate_cache2
moe_gemm_w4a16
.
gemm2_w4a16
(
sorted_token_ids
,
# sorted_token_ids.to(torch.uint16)
intermediate_cache2
,
# hidden_states
w2
,
# w2
intermediate_cache3
,
# gemm2_out
num_tokens_post_padded
,
expert_ids
,
# expert_id_vec
w2_scale
,
# scale_zero
curr_
topk_weights
,
# topk_weights
topk_weights
,
# topk_weights
64
,
# group_size
topk
=
topk
_ids
.
shape
[
1
]
,
# topk
topk
=
topk
,
# topk
mode
=
mode_2
)
# mode=gemm2_mode
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
return
out_hidden_states
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
out_hidden_states
)
return
out_hidden_states
\ No newline at end of file
vllm/model_executor/models/qwen3_moe.py
View file @
99ffef47
...
...
@@ -332,9 +332,13 @@ class Qwen3MoeModel(nn.Module):
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
config
=
config
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
# if self.config.quantization_config["bits"] == 4:
#
os.environ['LLAMA_NN'] = '0'
#
os.environ['LM_NN'] = '0'
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
...
...
@@ -352,10 +356,6 @@ class Qwen3MoeModel(nn.Module):
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment