Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
35dbdc41
Commit
35dbdc41
authored
Jun 18, 2025
by
王敏
Browse files
暂时去掉fused_moe中量化kernel的bug修复
parent
cff5452a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
50 deletions
+21
-50
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+21
-50
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
35dbdc41
...
...
@@ -183,8 +183,7 @@ def fused_moe_kernel_awq(
compute_type
:
tl
.
constexpr
,
has_zp
:
tl
.
constexpr
,
use_int4_w4a16
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
enable_expert_parallel
:
int
,):
use_int8_w8a16
:
tl
.
constexpr
):
pid
=
tl
.
program_id
(
axis
=
0
)
num_pid_m
=
tl
.
cdiv
(
EM
,
BLOCK_SIZE_M
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_SIZE_N
)
...
...
@@ -202,17 +201,6 @@ def fused_moe_kernel_awq(
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
# [block_m]
token_mask
=
offs_token
<
num_valid_tokens
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
if
enable_expert_parallel
:
if
off_experts
==
-
1
:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
)
return
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
# [block_n]
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
# 0, 1, 2, ...... , 127 # # [block_k]
...
...
@@ -220,6 +208,8 @@ def fused_moe_kernel_awq(
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
# [block_m, block_k]
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
if
use_int4_w4a16
:
# [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
# [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
...
...
@@ -350,8 +340,7 @@ def fused_moe_kernel_gptq_awq(
compute_type
:
tl
.
constexpr
,
has_zp
:
tl
.
constexpr
,
use_int4_w4a16
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
enable_expert_parallel
:
int
,):
use_int8_w8a16
:
tl
.
constexpr
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
...
...
@@ -405,23 +394,14 @@ def fused_moe_kernel_gptq_awq(
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
token_mask
=
offs_token
<
num_valid_tokens
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
if
enable_expert_parallel
:
if
off_experts
==
-
1
:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
)
return
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
if
use_int4_w4a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
(
offs_k
[:,
None
]
//
2
)
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
...
...
@@ -561,8 +541,7 @@ def fused_moe_kernel(
use_fp8_w8a8
:
tl
.
constexpr
,
use_int8_w8a8
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
,
per_channel_quant
:
tl
.
constexpr
,
enable_expert_parallel
:
int
,
per_channel_quant
:
tl
.
constexpr
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
...
...
@@ -625,13 +604,12 @@ def fused_moe_kernel(
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
)
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
return
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
).
to
(
tl
.
int64
)
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
token_mask
=
offs_token
<
num_valid_tokens
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
if
enable_expert_parallel
:
if
off_experts
==
-
1
:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
...
...
@@ -642,7 +620,7 @@ def fused_moe_kernel(
return
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
.
to
(
tl
.
int64
)
)
%
N
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
...
...
@@ -761,8 +739,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int4_w4a16
:
bool
,
per_channel_quant
:
bool
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
enable_expert_parallel
:
int
=
0
,)
->
None
:
use_nn_moe
:
Optional
[
bool
]
=
False
)
->
None
:
assert
topk_weights
is
not
None
or
not
mul_routed_weight
assert
topk_weights
is
None
or
topk_weights
.
stride
(
1
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
...
...
@@ -849,7 +826,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
enable_expert_parallel
=
enable_expert_parallel
,
**
config
,
)
else
:
...
...
@@ -888,7 +864,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
enable_expert_parallel
=
enable_expert_parallel
,
**
config
,
)
else
:
...
...
@@ -937,7 +912,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
per_channel_quant
=
per_channel_quant
,
enable_expert_parallel
=
enable_expert_parallel
,
# BLOCK_SIZE_K=BLOCK_SIZE_K,
**
config
,
)
...
...
@@ -1750,7 +1724,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
global_num_experts
,
expert_map
))
enable_expert_parallel
=
(
int
)(
expert_map
is
not
None
)
invoke_fused_moe_kernel
(
qcurr_hidden_states
,
w1
,
intermediate_cache1
,
...
...
@@ -1772,8 +1745,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
enable_expert_parallel
=
enable_expert_parallel
)
use_nn_moe
=
use_nn_moe
)
if
activation
==
"silu"
:
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
...
...
@@ -1834,8 +1806,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16
=
use_int4_w4a16
,
per_channel_quant
=
per_channel_quant
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
enable_expert_parallel
=
enable_expert_parallel
)
use_nn_moe
=
use_nn_moe
)
ops
.
moe_sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment