Commit 3c7c9ca2 authored by 王敏
Browse files

[fix]1.临时修复deepgemm导致dp+ep精度异常问题;2.解决mtp>1强制走piecewise的问题

parent e7dcfb5b
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
import torch import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.forward_context import get_forward_context, is_forward_context_available from vllm.forward_context import get_forward_context, is_forward_context_available
...@@ -33,8 +36,10 @@ from vllm.utils.deep_gemm import ( ...@@ -33,8 +36,10 @@ from vllm.utils.deep_gemm import (
from vllm.utils.math_utils import cdiv, round_up from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.import_utils import has_deep_gemm from vllm.utils.import_utils import has_deep_gemm
from vllm.model_executor.layers.activation import SiluAndMul
from lightop import fuse_silu_mul_quant_ep from lightop import fuse_silu_mul_quant_ep
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
if has_deep_gemm(): if has_deep_gemm():
from deep_gemm import m_grouped_w8a8_gemm_nt_masked from deep_gemm import m_grouped_w8a8_gemm_nt_masked
else: else:
...@@ -45,6 +50,175 @@ else: ...@@ -45,6 +50,175 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
# ==============================================
# MoE grouped GEMM Triton kernel (int8 quantization + expert parallelism)
# Input layout (after All2All): A is [E, M, K], B is [E, N, K]
# Output: [E, M, N], written directly into the caller-provided output tensor
# ==============================================
@triton.jit
def moe_grouped_gemm_kernel(
    # Pointers
    A_ptr, B_ptr,
    A_scale_ptr, B_scale_ptr,
    token_counts_ptr,
    output_ptr,
    # Strides (expert/E stride, token/M stride, out-channel/N stride, feature/K stride)
    stride_A_E, stride_A_M, stride_A_K,
    stride_B_E, stride_B_N, stride_B_K,
    stride_A_scale_E, stride_A_scale_M,
    stride_B_scale_E, stride_B_scale_N,
    stride_out_E, stride_out_M, stride_out_N,
    # Fixed dimensions
    E: tl.constexpr,  # number of experts (kept for the interface; not read in the body)
    M: tl.constexpr,  # maximum number of tokens per expert
    N: tl.constexpr,  # output features per expert
    K: tl.constexpr,  # input feature dimension
    # Tile sizes
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute one [BLOCK_M, BLOCK_N] output tile of ``A[e] @ B[e]^T`` for expert ``e``.

    Grid axes are (expert, token tile, output-feature tile). Each program
    accumulates the product over K in float32, multiplies the result by the
    per-token scale of A and the per-output-channel scale of B, and stores it
    into ``output[e]``. Rows at or beyond ``token_counts[e]`` are never
    written, and tiles that start past that count exit immediately.

    Every load and the final store are masked against both the tensor bounds
    (M, N, K) and the per-expert token count, so the kernel stays in bounds
    when M, N or K is not a multiple of the tile size.
    """
    # ===================== 1. Program ids -> (expert, tile) coordinates =====================
    pid_e = tl.program_id(0)  # expert index (0 .. E-1)
    pid_m = tl.program_id(1)  # token-tile index
    pid_n = tl.program_id(2)  # output-feature-tile index

    # Number of tokens this expert actually has to process.
    token_cnt = tl.load(token_counts_ptr + pid_e)
    # Tiles that start beyond the live tokens have nothing to do.
    if pid_m * BLOCK_M >= token_cnt:
        return

    # ===================== 2. Base pointers for this expert =====================
    A_base = A_ptr + pid_e * stride_A_E                    # input A [E, M, K]
    B_base = B_ptr + pid_e * stride_B_E                    # weight B [E, N, K]
    A_scale_base = A_scale_ptr + pid_e * stride_A_scale_E
    B_scale_base = B_scale_ptr + pid_e * stride_B_scale_E
    out_base = output_ptr + pid_e * stride_out_E           # output [E, M, N]

    # Element offsets covered by this tile.
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # A row is live only if it is a real token of this expert AND inside the
    # tensor; the second condition guards against token_cnt > M and against
    # a tile overhanging M. Columns are live only inside N.
    row_mask = (offs_m < token_cnt) & (offs_m < M)
    col_mask = offs_n < N

    a_ptrs = A_base + (offs_m[:, None] * stride_A_M + offs_k[None, :] * stride_A_K)
    b_ptrs = B_base + (offs_n[:, None] * stride_B_N + offs_k[None, :] * stride_B_K)

    # ===================== 3. Accumulator =====================
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # ===================== 4. GEMM main loop over K =====================
    for k in range(0, K, BLOCK_K):
        k_mask = offs_k < K - k
        # Masked-out elements read as 0 so they contribute nothing to the dot
        # product and never touch memory outside the tensors.
        a = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0)
        b = tl.load(b_ptrs, mask=col_mask[:, None] & k_mask[None, :], other=0)
        # B tile is [BLOCK_N, BLOCK_K]; transpose it to [BLOCK_K, BLOCK_N].
        acc += tl.dot(a, tl.trans(b))
        # Advance both tiles along K.
        a_ptrs += BLOCK_K * stride_A_K
        b_ptrs += BLOCK_K * stride_B_K

    # ===================== 5. Dequantize (per-token x per-output-channel) =====================
    # Scale loads are masked as well: an unmasked load would read past the
    # scale tensors whenever the tile overhangs M or N.
    a_scale = tl.load(A_scale_base + offs_m * stride_A_scale_M, mask=row_mask, other=0.0)  # [BLOCK_M]
    b_scale = tl.load(B_scale_base + offs_n * stride_B_scale_N, mask=col_mask, other=0.0)  # [BLOCK_N]
    # out = (A @ B^T) * A_scale * B_scale
    result = acc * a_scale[:, None] * b_scale[None, :]

    # ===================== 6. Store into output [E, M, N] =====================
    out_ptrs = out_base + (offs_m[:, None] * stride_out_M + offs_n[None, :] * stride_out_N)
    # Only live tokens and in-range output channels are written.
    tl.store(out_ptrs, result, mask=row_mask[:, None] & col_mask[None, :])
# ==============================================
# Host-side wrapper (public entry point; derives strides and the launch grid)
# ==============================================
def moe_grouped_gemm(
    A: torch.Tensor,             # [E, M, K]
    B: torch.Tensor,             # [E, N, K] int8
    A_scale: torch.Tensor,       # [E, M, 1]
    B_scale: torch.Tensor,       # [E, N, 1]
    token_counts: torch.Tensor,  # [E]
    output: torch.Tensor,        # [E, M, N] (caller-provided, written in place)
):
    """Launch ``moe_grouped_gemm_kernel`` over every expert and return ``output``.

    Args:
        A: Activations, shape [E, M, K].
        B: Expert weights, shape [E, N, K].
        A_scale: Scales for A, shape [E, M, 1].
        B_scale: Scales for B, shape [E, N, 1].
        token_counts: Per-expert token counts, shape [E].
        output: Destination tensor, shape [E, M, N]; filled by the kernel.

    Returns:
        The same ``output`` tensor that was passed in.

    Raises:
        AssertionError: If shapes disagree, the tensors are not all on one
            device, or ``A`` is not a CUDA tensor.
    """
    # Shape validation.
    num_experts, max_tokens, hidden = A.shape
    _, out_features, _ = B.shape
    assert B.shape == (num_experts, out_features, hidden)
    assert A_scale.shape == (num_experts, max_tokens, 1)
    assert B_scale.shape == (num_experts, out_features, 1)
    assert token_counts.shape == (num_experts,)
    assert output.shape == (num_experts, max_tokens, out_features)

    # All operands must share one device, and that device must be CUDA.
    operands = (A, B, A_scale, B_scale, token_counts, output)
    assert all(t.device == A.device for t in operands)
    assert A.is_cuda

    # Fixed tile sizes.
    tile_m = tile_n = tile_k = 64

    # One program per (expert, token tile, output-feature tile).
    launch_grid = (
        num_experts,
        triton.cdiv(max_tokens, tile_m),
        triton.cdiv(out_features, tile_n),
    )

    # Strides are read from the tensors themselves; only the first two
    # strides of the scale tensors are passed to the kernel.
    a_stride_e, a_stride_m, a_stride_k = A.stride()
    b_stride_e, b_stride_n, b_stride_k = B.stride()
    a_scale_stride_e, a_scale_stride_m, _ = A_scale.stride()
    b_scale_stride_e, b_scale_stride_n, _ = B_scale.stride()
    out_stride_e, out_stride_m, out_stride_n = output.stride()

    moe_grouped_gemm_kernel[launch_grid](
        A, B,
        A_scale, B_scale,
        token_counts,
        output,
        stride_A_E=a_stride_e, stride_A_M=a_stride_m, stride_A_K=a_stride_k,
        stride_B_E=b_stride_e, stride_B_N=b_stride_n, stride_B_K=b_stride_k,
        stride_A_scale_E=a_scale_stride_e, stride_A_scale_M=a_scale_stride_m,
        stride_B_scale_E=b_scale_stride_e, stride_B_scale_N=b_scale_stride_n,
        stride_out_E=out_stride_e, stride_out_M=out_stride_m, stride_out_N=out_stride_n,
        E=num_experts, M=max_tokens, N=out_features, K=hidden,
        BLOCK_M=tile_m,
        BLOCK_N=tile_n,
        BLOCK_K=tile_k,
    )
    return output
def native_w8a8_perChannel_batch_matmul(q_a1_all, weight13, qa1_scale_all, w13_scale, output_dtype):
    """Reference batched matmul with scaling, computed in float32 via ``torch.bmm``.

    Computes ``qa1_scale_all * (q_a1_all @ weight13^T) * w13_scale^T`` per
    batch entry and casts the result to ``output_dtype``.

    Args:
        q_a1_all: Left operand; its last dimension must match ``weight13``'s.
            # assumes shape [E, M, K] - TODO confirm against callers
        weight13: Right operand, transposed on dims (1, 2) before the product.
            # assumes shape [E, N, K] - TODO confirm against callers
        qa1_scale_all: Multiplied onto the product by broadcasting.
        w13_scale: Transposed on dims (1, 2), then multiplied by broadcasting.
        output_dtype: dtype of the returned tensor.

    Raises:
        AssertionError: If the last dimensions of the two operands differ.
    """
    assert q_a1_all.shape[-1] == weight13.shape[-1], "Dimension mismatch"
    lhs = q_a1_all.to(torch.float32)
    rhs = weight13.to(torch.float32)
    # bmm of [E, M, K] with [E, K, N] yields [E, M, N].
    product = torch.bmm(lhs, rhs.transpose(1, 2))
    # Scale on the left first, then the transposed weight scale, so the
    # floating-point evaluation order is (scale_a * C) * scale_w.
    column_scale = w13_scale.transpose(1, 2)
    return (qa1_scale_all * product * column_scale).to(output_dtype)
def scales_shape_stride_dtype( def scales_shape_stride_dtype(
E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT
) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]:
...@@ -297,6 +471,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -297,6 +471,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.N = N self.N = N
self.K = K self.K = K
self.act_fn = SiluAndMul()
@staticmethod @staticmethod
def activation_format() -> mk.FusedMoEActivationFormat: def activation_format() -> mk.FusedMoEActivationFormat:
...@@ -466,20 +641,26 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -466,20 +641,26 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m, expected_m,
) )
elif self.quant_config.use_int8_w8a8: elif self.quant_config.use_int8_w8a8:
m_grouped_w8a8_gemm_nt_masked((a1q, a1q_scale), # m_grouped_w8a8_gemm_nt_masked((a1q, a1q_scale),
(w1, self.w1_scale), # (w1, self.w1_scale),
workspace1, # workspace1,
expert_num_tokens, # expert_num_tokens,
expected_m, # expected_m,
) # )
assert expert_num_tokens is not None # assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens) # a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale), # m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
(w2, self.w2_scale), # (w2, self.w2_scale),
output, # output,
expert_num_tokens, # expert_num_tokens,
expected_m) # expected_m)
moe_grouped_gemm(a1q, w1, a1q_scale, self.w1_scale, expert_num_tokens, workspace1)
act_out = self.act_fn(workspace1)
a2q, a2q_scale = per_token_quant_int8(act_out)
moe_grouped_gemm(a2q, w2, a2q_scale, self.w2_scale, expert_num_tokens, output)
else: else:
raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}") raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}")
...@@ -297,21 +297,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -297,21 +297,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Dispatch # Dispatch
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids) dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
quant_type = 0
if self.use_int8_dispatch:
quant_type = 1
elif self.use_fp8_dispatch:
quant_type = 2
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch( expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
a1, a1,
dispatch_topk_ids, dispatch_topk_ids,
self.max_tokens_per_rank, self.max_tokens_per_rank,
num_experts, num_experts,
use_fp8=self.use_fp8_dispatch or self.use_int8_dispatch, quant_type = quant_type,
use_int8=self.use_int8_dispatch, fp8_round_scale=False,
round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
if qc_a1_gscale_or_scale is not None
else dict()
),
async_finish=False, async_finish=False,
return_recv_hook=True, return_recv_hook=True,
) )
......
...@@ -370,27 +370,28 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -370,27 +370,28 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = [] if not self.use_deepep:
for ii in range(layer.w13_weight.shape[0]): w1_marlin_list = []
if not self.use_deepep: for ii in range(layer.w13_weight.shape[0]):
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii]) if not self.use_deepep:
else: w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii]) else:
w1_marlin_list.append(w1_marlin_in) w1_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
w1_marlin = torch.stack(w1_marlin_list, dim=0) w1_marlin_list.append(w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
del w1_marlin_list
w2_marlin_list = [] del w1_marlin_list
for ii in range(layer.w2_weight.shape[0]): w2_marlin_list = []
if not self.use_deepep: for ii in range(layer.w2_weight.shape[0]):
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii]) if not self.use_deepep:
else: w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii]) else:
w2_marlin_list.append(w2_marlin_in) w2_marlin_in = w8a8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin = torch.stack(w2_marlin_list, dim=0) w2_marlin_list.append(w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False) layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def apply( def apply(
self, self,
......
...@@ -202,7 +202,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig): ...@@ -202,7 +202,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
_cudagraph_support: ClassVar[AttentionCGSupport] = ( _cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE AttentionCGSupport.UNIFORM_BATCH
) )
reorder_batch_threshold: int = 1 reorder_batch_threshold: int = 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment