Commit 4fca01b8 authored by wangmin6's avatar wangmin6 Committed by zhangzbb
Browse files

[Perf]优化EP低延迟模式下调度,消除调度空泡

parent 8c96d505
......@@ -340,7 +340,8 @@ def set_forward_context(
forward_start_time = time.perf_counter()
dp_metadata: DPMetadata | None = None
if vllm_config.parallel_config.data_parallel_size > 1 and (
if vllm_config.parallel_config.data_parallel_size > 1 and \
envs.VLLM_ALL2ALL_BACKEND != "deepep_low_latency" and (
attn_metadata is not None or num_tokens is not None
):
# If num_tokens_across_dp hasn't already been initialized, then
......
......@@ -3,6 +3,9 @@
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.forward_context import get_forward_context, is_forward_context_available
......@@ -33,8 +36,10 @@ from vllm.utils.deep_gemm import (
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.import_utils import has_deep_gemm
from vllm.model_executor.layers.activation import SiluAndMul
from lightop import fuse_silu_mul_quant_ep
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
if has_deep_gemm():
from deepgemm import m_grouped_w8a8_gemm_nt_masked
else:
......@@ -45,6 +50,161 @@ else:
logger = init_logger(__name__)
# ==============================================
# MOE Grouped GEMM Triton内核 (int8量化 + 专家并行)
# 输入布局:All2All后 -> [E, M, K] / [E, N, K]
# 输出:[E, M, N] 直接写入传入的output张量
# ==============================================
@triton.jit
def moe_grouped_gemm_kernel(
# 指针
A_ptr, B_ptr,
A_scale_ptr, B_scale_ptr,
token_counts_ptr,
output_ptr,
# 维度步长 (Batch/E维度步长, M/Token步长, N/Out通道步长, K/特征步长)
stride_A_E, stride_A_M, stride_A_K,
stride_B_E, stride_B_N, stride_B_K,
stride_A_scale_E, stride_A_scale_M,
stride_B_scale_E, stride_B_scale_N,
stride_out_E, stride_out_M, stride_out_N,
# 固定维度
E: tl.constexpr, # 专家总数
M: tl.constexpr, # 每个专家最大Token数
N: tl.constexpr, # 每个专家输出维度
K: tl.constexpr, # 输入特征维度
# 分块参数 (T自动调优)
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
# ===================== 1. 专家ID + 计算坐标 =====================
# 程序ID对应:专家ID(E) + Token分块(M) + 输出分块(N)
pid_e = tl.program_id(0) # 专家维度 (0~E-1)
pid_m = tl.program_id(1) # Token分块维度
pid_n = tl.program_id(2) # 输出分块维度
# 当前专家实际需要计算的Token数量
token_cnt = tl.load(token_counts_ptr + pid_e)
# 超出实际Token数直接退出 (动态Token数)
if pid_m * BLOCK_M >= token_cnt:
return
# ===================== 2. 计算当前分块的内存偏移 =====================
# 输入A [E, M, K]
A_base = A_ptr + pid_e * stride_A_E
# 权重B [E, N, K]
B_base = B_ptr + pid_e * stride_B_E
# Scale
A_scale_base = A_scale_ptr + pid_e * stride_A_scale_E
B_scale_base = B_scale_ptr + pid_e * stride_B_scale_E
# 输出 [E, M, N]
out_base = output_ptr + pid_e * stride_out_E
# 分块坐标
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
# 内存索引
a_ptrs = A_base + (offs_m[:, None] * stride_A_M + offs_k[None, :] * stride_A_K)
b_ptrs = B_base + (offs_n[:, None] * stride_B_N + offs_k[None, :] * stride_B_K)
# ===================== 3. 初始化累加器 =====================
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# ===================== 4. K维度循环计算GEMM (int8矩阵乘) =====================
for k in range(0, K, BLOCK_K):
# 加载int8数据 (保持int8精度)
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[None, :] < K - k, other=0.0)
# 矩阵乘累加
acc += tl.dot(a, tl.trans(b)) # B: [N,K] -> 转置为[K,N]
# 指针步进
a_ptrs += BLOCK_K * stride_A_K
b_ptrs += BLOCK_K * stride_B_K
# ===================== 5. int8反量化 (Per-Token + Per-Output Channel) =====================
# 加载当前专家的scale
a_scale = tl.load(A_scale_base + offs_m * stride_A_scale_M) # [BLOCK_M]
b_scale = tl.load(B_scale_base + offs_n * stride_B_scale_N) # [BLOCK_N]
# 反量化:out = (int8_mm) * A_scale * B_scale
result = acc * a_scale[:, None] * b_scale[None, :]
# ===================== 6. 写入输出 [E, M, N] =====================
out_ptrs = out_base + (offs_m[:, None] * stride_out_M + offs_n[None, :] * stride_out_N)
# 掩码:只写有效Token + 有效输出通道
mask_m = offs_m < token_cnt
mask_n = offs_n < N
mask = mask_m[:, None] & mask_n[None, :]
tl.store(out_ptrs, result, mask=mask)
# ==============================================
# 包装函数 (对外调用接口,自动处理步长/启动网格)
# ==============================================
def moe_grouped_gemm(
A: torch.Tensor, # [E, M, K]
B: torch.Tensor, # [E, N, K] int8
A_scale: torch.Tensor, # [E, M, 1]
B_scale: torch.Tensor, # [E, N, 1]
token_counts: torch.Tensor, # [E]
output: torch.Tensor, # [E, M, N] (传入,直接写入)
):
# 维度校验
E, M, K = A.shape
_, N, _ = B.shape
assert B.shape == (E, N, K)
assert A_scale.shape == (E, M, 1)
assert B_scale.shape == (E, N, 1)
assert token_counts.shape == (E,)
assert output.shape == (E, M, N)
# 设备统一
assert A.device == B.device == A_scale.device == B_scale.device == token_counts.device == output.device
assert A.is_cuda
# 自动分块大小 (适配主流GPU)
BLOCK_M = 64
BLOCK_N = 64
BLOCK_K = 64
# 计算网格:[E, ceil(M/BLOCK_M), ceil(N/BLOCK_N)]
grid = (
E,
triton.cdiv(M, BLOCK_M),
triton.cdiv(N, BLOCK_N),
)
# 启动内核
moe_grouped_gemm_kernel[grid](
# 数据指针
A, B,
A_scale, B_scale,
token_counts,
output,
# 步长 (按最后一维连续的张量自动计算)
stride_A_E=A.stride(0), stride_A_M=A.stride(1), stride_A_K=A.stride(2),
stride_B_E=B.stride(0), stride_B_N=B.stride(1), stride_B_K=B.stride(2),
stride_A_scale_E=A_scale.stride(0), stride_A_scale_M=A_scale.stride(1),
stride_B_scale_E=B_scale.stride(0), stride_B_scale_N=B_scale.stride(1),
stride_out_E=output.stride(0), stride_out_M=output.stride(1), stride_out_N=output.stride(2),
# 固定维度
E=E, M=M, N=N, K=K,
# 分块参数
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)
return output
def scales_shape_stride_dtype(
E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT
) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]:
......@@ -297,6 +457,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.N = N
self.K = K
self.act_fn = SiluAndMul()
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
......@@ -414,7 +575,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
use_nn_moe: bool | None = False,
**_
):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
......@@ -436,11 +597,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
expected_m = self.estimate_expected_m(
global_num_experts=global_num_experts,
max_tokens_per_expert=max_num_tokens,
topk=topk_ids.size(-1),
)
# expected_m = self.estimate_expected_m(
# global_num_experts=global_num_experts,
# max_tokens_per_expert=max_num_tokens,
# topk=topk_ids.size(-1),
# )
expected_m = self.get_expected_m()
if self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8:
fp8_m_grouped_gemm_nt_masked(
......
......@@ -297,21 +297,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Dispatch
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
quant_type = 0
if self.use_int8_dispatch:
quant_type = 1
elif self.use_fp8_dispatch:
quant_type = 2
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
a1,
dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch or self.use_int8_dispatch,
use_int8=self.use_int8_dispatch,
round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
if qc_a1_gscale_or_scale is not None
else dict()
),
quant_type = quant_type,
fp8_round_scale=False,
async_finish=False,
return_recv_hook=True,
)
......
......@@ -853,7 +853,7 @@ class FusedMoE(CustomOp):
def use_dp_chunking(self) -> bool:
return (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
#or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
......
......@@ -406,6 +406,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.quant_config = quant_config
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
self.expected_m = max_num_tokens
@staticmethod
def expects_unquantized_inputs(
......@@ -775,6 +776,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
def set_expected_m(self, expected_m):
self.expected_m = expected_m
def get_expected_m(self):
return self.expected_m
def _slice_scales(
scales: torch.Tensor | None, start: int, end: int
......@@ -1074,6 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module):
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
"""
expected_m = (
hidden_states.shape[0] * self.fused_experts.num_dispatchers * topk_ids.shape[1]
+ global_num_experts
) // global_num_experts
self.fused_experts.set_expected_m(expected_m)
if not self.prepare_finalize.supports_async():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
......
......@@ -6,6 +6,7 @@ import numpy as np
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
......@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
]
"""
if parallel_config.data_parallel_size == 1:
if parallel_config.data_parallel_size == 1 or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency":
# Early exit.
return False, None, cudagraph_mode
......
......@@ -183,6 +183,7 @@ from .utils import (
sanity_check_mm_encoder_outputs,
)
from vllm.v1.spec_decode.utils import DraftProbs
from vllm.utils.torch_utils import async_tensor_h2d
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
......@@ -4789,9 +4790,6 @@ class GPUModelRunner(
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
model_kwargs = self._init_model_kwargs()
else:
self.input_ids.gpu[:num_tokens_padded] = torch.randint(0, self.model_config.get_vocab_size(),
(num_tokens_padded,),
dtype=torch.int32)
input_ids = self.input_ids.gpu[:num_tokens_padded]
inputs_embeds = None
......@@ -4904,9 +4902,15 @@ class GPUModelRunner(
self.eplb_step(is_dummy=True, is_profile=is_profile)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
logit_indices_device = torch.from_numpy(logit_indices).to(
self.device, non_blocking=True
)
# logit_indices_device = torch.from_numpy(logit_indices).to(
# self.device, non_blocking=True
# )
logit_indices = logit_indices.tolist()
logit_indices_device = async_tensor_h2d(
logit_indices,
dtype=torch.int32,
target_device=self.device,
pin_memory=True)
return hidden_states, hidden_states[logit_indices_device]
@torch.inference_mode()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment