Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3c7c9ca2
"vscode:/vscode.git/clone" did not exist on "101f1481f9c4e3e108d30ce2f8715ee89288992b"
Commit
3c7c9ca2
authored
Apr 11, 2026
by
王敏
Browse files
[fix]1.临时修复deepgemm导致dp+ep精度异常问题;2.解决mtp>1强制走piecewise的问题
parent
e7dcfb5b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
227 additions
and
47 deletions
+227
-47
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+196
-15
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+8
-10
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+22
-21
vllm/v1/attention/backends/mla/indexer.py
vllm/v1/attention/backends/mla/indexer.py
+1
-1
No files found.
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
3c7c9ca2
...
@@ -3,6 +3,9 @@
...
@@ -3,6 +3,9 @@
import
torch
import
torch
import
torch.nn.functional
as
F
import
triton
import
triton.language
as
tl
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.forward_context
import
get_forward_context
,
is_forward_context_available
from
vllm.forward_context
import
get_forward_context
,
is_forward_context_available
...
@@ -33,8 +36,10 @@ from vllm.utils.deep_gemm import (
...
@@ -33,8 +36,10 @@ from vllm.utils.deep_gemm import (
from
vllm.utils.math_utils
import
cdiv
,
round_up
from
vllm.utils.math_utils
import
cdiv
,
round_up
from
vllm.utils.import_utils
import
has_deep_gemm
from
vllm.utils.import_utils
import
has_deep_gemm
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
lightop
import
fuse_silu_mul_quant_ep
from
lightop
import
fuse_silu_mul_quant_ep
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
if
has_deep_gemm
():
if
has_deep_gemm
():
from
deep_gemm
import
m_grouped_w8a8_gemm_nt_masked
from
deep_gemm
import
m_grouped_w8a8_gemm_nt_masked
else
:
else
:
...
@@ -45,6 +50,175 @@ else:
...
@@ -45,6 +50,175 @@ else:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# ==============================================
# MOE Grouped GEMM Triton内核 (int8量化 + 专家并行)
# 输入布局:All2All后 -> [E, M, K] / [E, N, K]
# 输出:[E, M, N] 直接写入传入的output张量
# ==============================================
@
triton
.
jit
def
moe_grouped_gemm_kernel
(
# 指针
A_ptr
,
B_ptr
,
A_scale_ptr
,
B_scale_ptr
,
token_counts_ptr
,
output_ptr
,
# 维度步长 (Batch/E维度步长, M/Token步长, N/Out通道步长, K/特征步长)
stride_A_E
,
stride_A_M
,
stride_A_K
,
stride_B_E
,
stride_B_N
,
stride_B_K
,
stride_A_scale_E
,
stride_A_scale_M
,
stride_B_scale_E
,
stride_B_scale_N
,
stride_out_E
,
stride_out_M
,
stride_out_N
,
# 固定维度
E
:
tl
.
constexpr
,
# 专家总数
M
:
tl
.
constexpr
,
# 每个专家最大Token数
N
:
tl
.
constexpr
,
# 每个专家输出维度
K
:
tl
.
constexpr
,
# 输入特征维度
# 分块参数 (T自动调优)
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
):
# ===================== 1. 专家ID + 计算坐标 =====================
# 程序ID对应:专家ID(E) + Token分块(M) + 输出分块(N)
pid_e
=
tl
.
program_id
(
0
)
# 专家维度 (0~E-1)
pid_m
=
tl
.
program_id
(
1
)
# Token分块维度
pid_n
=
tl
.
program_id
(
2
)
# 输出分块维度
# 当前专家实际需要计算的Token数量
token_cnt
=
tl
.
load
(
token_counts_ptr
+
pid_e
)
# 超出实际Token数直接退出 (动态Token数)
if
pid_m
*
BLOCK_M
>=
token_cnt
:
return
# ===================== 2. 计算当前分块的内存偏移 =====================
# 输入A [E, M, K]
A_base
=
A_ptr
+
pid_e
*
stride_A_E
# 权重B [E, N, K]
B_base
=
B_ptr
+
pid_e
*
stride_B_E
# Scale
A_scale_base
=
A_scale_ptr
+
pid_e
*
stride_A_scale_E
B_scale_base
=
B_scale_ptr
+
pid_e
*
stride_B_scale_E
# 输出 [E, M, N]
out_base
=
output_ptr
+
pid_e
*
stride_out_E
# 分块坐标
offs_m
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_n
=
pid_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
offs_k
=
tl
.
arange
(
0
,
BLOCK_K
)
# 内存索引
a_ptrs
=
A_base
+
(
offs_m
[:,
None
]
*
stride_A_M
+
offs_k
[
None
,
:]
*
stride_A_K
)
b_ptrs
=
B_base
+
(
offs_n
[:,
None
]
*
stride_B_N
+
offs_k
[
None
,
:]
*
stride_B_K
)
# ===================== 3. 初始化累加器 =====================
acc
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
# ===================== 4. K维度循环计算GEMM (int8矩阵乘) =====================
for
k
in
range
(
0
,
K
,
BLOCK_K
):
# 加载int8数据 (保持int8精度)
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
k
,
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
k
,
other
=
0.0
)
# 矩阵乘累加
acc
+=
tl
.
dot
(
a
,
tl
.
trans
(
b
))
# B: [N,K] -> 转置为[K,N]
# 指针步进
a_ptrs
+=
BLOCK_K
*
stride_A_K
b_ptrs
+=
BLOCK_K
*
stride_B_K
# ===================== 5. int8反量化 (Per-Token + Per-Output Channel) =====================
# 加载当前专家的scale
a_scale
=
tl
.
load
(
A_scale_base
+
offs_m
*
stride_A_scale_M
)
# [BLOCK_M]
b_scale
=
tl
.
load
(
B_scale_base
+
offs_n
*
stride_B_scale_N
)
# [BLOCK_N]
# 反量化:out = (int8_mm) * A_scale * B_scale
result
=
acc
*
a_scale
[:,
None
]
*
b_scale
[
None
,
:]
# ===================== 6. 写入输出 [E, M, N] =====================
out_ptrs
=
out_base
+
(
offs_m
[:,
None
]
*
stride_out_M
+
offs_n
[
None
,
:]
*
stride_out_N
)
# 掩码:只写有效Token + 有效输出通道
mask_m
=
offs_m
<
token_cnt
mask_n
=
offs_n
<
N
mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
tl
.
store
(
out_ptrs
,
result
,
mask
=
mask
)
# ==============================================
# 包装函数 (对外调用接口,自动处理步长/启动网格)
# ==============================================
def
moe_grouped_gemm
(
A
:
torch
.
Tensor
,
# [E, M, K]
B
:
torch
.
Tensor
,
# [E, N, K] int8
A_scale
:
torch
.
Tensor
,
# [E, M, 1]
B_scale
:
torch
.
Tensor
,
# [E, N, 1]
token_counts
:
torch
.
Tensor
,
# [E]
output
:
torch
.
Tensor
,
# [E, M, N] (传入,直接写入)
):
# 维度校验
E
,
M
,
K
=
A
.
shape
_
,
N
,
_
=
B
.
shape
assert
B
.
shape
==
(
E
,
N
,
K
)
assert
A_scale
.
shape
==
(
E
,
M
,
1
)
assert
B_scale
.
shape
==
(
E
,
N
,
1
)
assert
token_counts
.
shape
==
(
E
,)
assert
output
.
shape
==
(
E
,
M
,
N
)
# 设备统一
assert
A
.
device
==
B
.
device
==
A_scale
.
device
==
B_scale
.
device
==
token_counts
.
device
==
output
.
device
assert
A
.
is_cuda
# 自动分块大小 (适配主流GPU)
BLOCK_M
=
64
BLOCK_N
=
64
BLOCK_K
=
64
# 计算网格:[E, ceil(M/BLOCK_M), ceil(N/BLOCK_N)]
grid
=
(
E
,
triton
.
cdiv
(
M
,
BLOCK_M
),
triton
.
cdiv
(
N
,
BLOCK_N
),
)
# 启动内核
moe_grouped_gemm_kernel
[
grid
](
# 数据指针
A
,
B
,
A_scale
,
B_scale
,
token_counts
,
output
,
# 步长 (按最后一维连续的张量自动计算)
stride_A_E
=
A
.
stride
(
0
),
stride_A_M
=
A
.
stride
(
1
),
stride_A_K
=
A
.
stride
(
2
),
stride_B_E
=
B
.
stride
(
0
),
stride_B_N
=
B
.
stride
(
1
),
stride_B_K
=
B
.
stride
(
2
),
stride_A_scale_E
=
A_scale
.
stride
(
0
),
stride_A_scale_M
=
A_scale
.
stride
(
1
),
stride_B_scale_E
=
B_scale
.
stride
(
0
),
stride_B_scale_N
=
B_scale
.
stride
(
1
),
stride_out_E
=
output
.
stride
(
0
),
stride_out_M
=
output
.
stride
(
1
),
stride_out_N
=
output
.
stride
(
2
),
# 固定维度
E
=
E
,
M
=
M
,
N
=
N
,
K
=
K
,
# 分块参数
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
BLOCK_K
=
BLOCK_K
,
)
return
output
def
native_w8a8_perChannel_batch_matmul
(
q_a1_all
,
weight13
,
qa1_scale_all
,
w13_scale
,
output_dtype
):
A
=
q_a1_all
.
to
(
torch
.
float32
)
B
=
weight13
.
to
(
torch
.
float32
)
assert
A
.
shape
[
-
1
]
==
B
.
shape
[
-
1
],
"Dimension mismatch"
C
=
torch
.
bmm
(
A
,
B
.
transpose
(
1
,
2
))
# [E, M, K]
C
=
qa1_scale_all
*
C
*
w13_scale
.
transpose
(
1
,
2
)
# Broadcast per-column scale
C
=
C
.
to
(
output_dtype
)
return
C
def
scales_shape_stride_dtype
(
def
scales_shape_stride_dtype
(
E
:
int
,
T
:
int
,
G
:
int
,
quant_scale_fmt
:
DeepGemmQuantScaleFMT
E
:
int
,
T
:
int
,
G
:
int
,
quant_scale_fmt
:
DeepGemmQuantScaleFMT
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
...
@@ -297,6 +471,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -297,6 +471,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self
.
N
=
N
self
.
N
=
N
self
.
K
=
K
self
.
K
=
K
self
.
act_fn
=
SiluAndMul
()
@
staticmethod
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
...
@@ -466,20 +641,26 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -466,20 +641,26 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m
,
expected_m
,
)
)
elif
self
.
quant_config
.
use_int8_w8a8
:
elif
self
.
quant_config
.
use_int8_w8a8
:
m_grouped_w8a8_gemm_nt_masked
((
a1q
,
a1q_scale
),
# m_grouped_w8a8_gemm_nt_masked((a1q, a1q_scale),
(
w1
,
self
.
w1_scale
),
# (w1, self.w1_scale),
workspace1
,
# workspace1,
expert_num_tokens
,
# expert_num_tokens,
expected_m
,
# expected_m,
)
# )
assert
expert_num_tokens
is
not
None
# assert expert_num_tokens is not None
# a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
# m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
# (w2, self.w2_scale),
# output,
# expert_num_tokens,
# expected_m)
moe_grouped_gemm
(
a1q
,
w1
,
a1q_scale
,
self
.
w1_scale
,
expert_num_tokens
,
workspace1
)
act_out
=
self
.
act_fn
(
workspace1
)
a2q
,
a2q_scale
=
per_token_quant_int8
(
act_out
)
moe_grouped_gemm
(
a2q
,
w2
,
a2q_scale
,
self
.
w2_scale
,
expert_num_tokens
,
output
)
a2q
,
a2q_scale
=
fuse_silu_mul_quant_ep
(
workspace1
,
expert_num_tokens
)
m_grouped_w8a8_gemm_nt_masked
((
a2q
,
a2q_scale
),
(
w2
,
self
.
w2_scale
),
output
,
expert_num_tokens
,
expected_m
)
else
:
else
:
raise
ValueError
(
f
"Unsupported dtype
{
self
.
quant_config
.
quant_dtype
}
"
)
raise
ValueError
(
f
"Unsupported dtype
{
self
.
quant_config
.
quant_dtype
}
"
)
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
3c7c9ca2
...
@@ -297,21 +297,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -297,21 +297,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Dispatch
# Dispatch
dispatch_topk_ids
=
self
.
_map_global_to_physical_ids
(
topk_ids
)
dispatch_topk_ids
=
self
.
_map_global_to_physical_ids
(
topk_ids
)
quant_type
=
0
if
self
.
use_int8_dispatch
:
quant_type
=
1
elif
self
.
use_fp8_dispatch
:
quant_type
=
2
expert_x
,
expert_num_tokens
,
handle
,
_
,
hook
=
self
.
buffer
.
low_latency_dispatch
(
expert_x
,
expert_num_tokens
,
handle
,
_
,
hook
=
self
.
buffer
.
low_latency_dispatch
(
a1
,
a1
,
dispatch_topk_ids
,
dispatch_topk_ids
,
self
.
max_tokens_per_rank
,
self
.
max_tokens_per_rank
,
num_experts
,
num_experts
,
use_fp8
=
self
.
use_fp8_dispatch
or
self
.
use_int8_dispatch
,
quant_type
=
quant_type
,
use_int8
=
self
.
use_int8_dispatch
,
fp8_round_scale
=
False
,
round_scale
=
self
.
use_ue8m0_dispatch
,
use_ue8m0
=
self
.
use_ue8m0_dispatch
,
**
(
dict
(
use_nvfp4
=
True
)
if
use_nvfp4
else
dict
()),
**
(
dict
(
x_global_scale
=
qc_a1_gscale_or_scale
)
if
qc_a1_gscale_or_scale
is
not
None
else
dict
()
),
async_finish
=
False
,
async_finish
=
False
,
return_recv_hook
=
True
,
return_recv_hook
=
True
,
)
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
3c7c9ca2
...
@@ -370,6 +370,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -370,6 +370,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer
.
w2_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
if
not
self
.
use_deepep
:
w1_marlin_list
=
[]
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
if
not
self
.
use_deepep
:
if
not
self
.
use_deepep
:
...
...
vllm/v1/attention/backends/mla/indexer.py
View file @
3c7c9ca2
...
@@ -202,7 +202,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
...
@@ -202,7 +202,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
class
DeepseekV32IndexerMetadataBuilder
(
AttentionMetadataBuilder
):
class
DeepseekV32IndexerMetadataBuilder
(
AttentionMetadataBuilder
):
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
(
_cudagraph_support
:
ClassVar
[
AttentionCGSupport
]
=
(
AttentionCGSupport
.
UNIFORM_
SINGLE_TOKEN_DECODE
AttentionCGSupport
.
UNIFORM_
BATCH
)
)
reorder_batch_threshold
:
int
=
1
reorder_batch_threshold
:
int
=
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment