Commit e7a1f94c authored by maxiao1

Use groupgemm to adapt the low-latency mode

parent 3246cea1
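In DeepEP's low-latency mode each expert receives a fixed-size block of dispatched tokens plus a masked_m count of valid rows, and a masked grouped GEMM multiplies each expert's weights against only that valid prefix. A minimal PyTorch reference of the operation (plain bfloat16, no quantization; the code in this commit instead calls lightop's w4a8 masked kernels on per-token int8-quantized activations) could look like this:

import torch

def masked_grouped_gemm_reference(x, w, masked_m):
    # Reference semantics only (function name is illustrative, not from lightop):
    #   x: (num_experts, max_m, k); only the first masked_m[e] rows per expert are valid
    #   w: (num_experts, n, k); per-expert weights in the "nt" layout
    #   returns: (num_experts, max_m, n)
    num_experts, max_m, _ = x.shape
    out = x.new_zeros(num_experts, max_m, w.size(1))
    for e in range(num_experts):
        m = int(masked_m[e])
        if m > 0:
            out[e, :m] = x[e, :m] @ w[e].transpose(0, 1)
    return out

The w4a8 path added below dispatches bfloat16 activations (use_fp8=False), quantizes them per token to int8 after dispatch, runs one masked grouped GEMM for the gate/up projection, applies a fused SiLU-mul with requantization, and runs a second masked grouped GEMM for the down projection.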
......@@ -19,8 +19,8 @@ logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = get_bool_env_var(
"USE_VLLM_CUSTOM_ALLREDUCE", default="false"
)
use_dcu_custom_allreduce= get_bool_env_var(
"USE_DCU_CUSTOM_ALLREDUCE", default="false"
use_dcu_custom_allreduce = get_bool_env_var(
"USE_DCU_CUSTOM_ALLREDUCE", default="true"
)
if not is_hpu():
......
......@@ -2,7 +2,6 @@ import logging
import torch
import triton
from sglang.srt.utils import ceil_div, is_cuda
logger = logging.getLogger(__name__)
......
......@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch
import torch.distributed as dist
from sglang.srt import single_batch_overlap
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe import (
......@@ -39,6 +39,8 @@ if TYPE_CHECKING:
DeepEPNormalOutput,
DispatchOutput,
)
from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
_is_hip = is_hip()
_is_npu = is_npu()
......@@ -494,16 +496,19 @@ class DeepEPMoE(EPMoE):
f"Dispatch output is not supported"
)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if (
get_moe_runner_backend().is_flashinfer_cutedsl()
and self.quant_config.get_name() == "modelopt_fp4"
):
return self.forward_flashinfer_cutedsl(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
assert down_gemm_overlap_args is None
return self.forward_deepgemm_masked(dispatch_output)
if self.use_w4a8_marlin:
return self.forward_groupgemm_w4a8_marlin_masked(dispatch_output)
else:
if (
get_moe_runner_backend().is_flashinfer_cutedsl()
and self.quant_config.get_name() == "modelopt_fp4"
):
return self.forward_flashinfer_cutedsl(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
assert down_gemm_overlap_args is None
return self.forward_deepgemm_masked(dispatch_output)
else:
raise ValueError(
f"Dispatch output format {dispatch_output.format} is not supported"
......@@ -561,29 +566,35 @@ class DeepEPMoE(EPMoE):
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = (
dispatch_output
)
#hidden_states_int8, hidden_states_scale = hidden_states_int8
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
# if num_recv_tokens_per_expert is None:
return hidden_states_int8.bfloat16()
# expert_output = self.quant_method.apply_ep(
# layer=self,
# x=dispatch_output,
# topk_weights=topk_weights,
# topk_ids=topk_idx,
# global_num_experts=self.global_num_experts,
# expert_map=self.expert_map,
# activation=self.activation,
# apply_router_weight_on_input=self.apply_router_weight_on_input,
# use_nn_moe=self.use_nn_moe,
# num_local_tokens=dispatch_recv_num_token,
# config_select_bs=hidden_states.shape[0],
# scales=dispatch_scales if self.use_int8_dispatch else None
# # routed_scaling_factor=self.routed_scaling_factor,
# )
# return expert_output
all_tokens = sum(num_recv_tokens_per_expert)
if all_tokens <= 0:
return hidden_states.bfloat16()
num_local_tokens = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int32, device="cuda")
expert_output = self.quant_method.apply_ep(
x=hidden_states,
w1=self.w13_weight,
w2=self.w2_weight,
topk_ids=topk_idx,
topk_weights=topk_weights,
global_num_experts=self.moe_runner_config.num_experts,
expert_map=self.expert_map,
activation=self.moe_runner_config.activation,
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
use_nn_moe=False,
w1_scale=self.w13_weight_scale,
w2_scale=self.w2_weight_scale,
routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
# num_local_tokens=num_local_tokens,
)
return expert_output
def forward_deepgemm_contiguous(
self,
......@@ -763,6 +774,56 @@ class DeepEPMoE(EPMoE):
dispatch_output=dispatch_output,
)
def forward_groupgemm_w4a8_marlin_masked(
self,
dispatch_output: DeepEPLLOutput,
):
hidden_states, _, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
# base shapes
num_groups, m, k = hidden_states.size()
expected_m = m // 2  # shape required by the grouped-GEMM kernel
# ---- first quant: ensure float input for quantizer ----
q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
# ---- weights & scales ----
w13_weight = self.w13_weight
w13_scales = self.w13_weight_scale
w2_weight = self.w2_weight
w2_scales = self.w2_weight_scale
n1 = w13_scales.size(1)
gateup_output = torch.empty((num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16)
# ---- first GEMM ----
m_grouped_w4a8_gemm_nt_masked(
(q_a1_all, q_a1_scale),
(w13_weight, w13_scales),
gateup_output,
masked_m,
expected_m,
)
q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
# ---- second GEMM ----
n2 = w2_scales.size(1)
down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
m_grouped_w4a8_gemm_nt_masked(
(q_a2_all, q_a2_scale),
(w2_weight, w2_scales),
down_output,
masked_m,
expected_m,
)
return down_output
def forward_deepgemm_masked(
self,
dispatch_output: DeepEPLLOutput,
......
......@@ -98,6 +98,49 @@ class FusedMoeWeightScaleSupported(Enum):
GROUP = "group"
BLOCK = "block"
def determine_expert_map(
ep_size: int, ep_rank: int,
global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]:
"""
Calculates how many experts should be assigned to each rank for EP and
creates a mapping from global to local expert index. Experts are
distributed evenly across ranks. Any remaining are assigned to the
last rank.
Args:
ep_size (int): The size of the expert parallel group.
ep_rank (int): The rank of the current process in the expert parallel group.
global_num_experts (int): The total number of experts in the model.
Returns:
tuple[int, Optional[torch.Tensor]]: A tuple containing:
- local_num_experts (int): The number of experts assigned
to the current rank.
- expert_map (Optional[torch.Tensor]): A tensor of shape
(global_num_experts,) mapping from global to local index.
Contains -1 for experts not assigned to the current rank.
Returns None if ep_size is 1.
"""
assert ep_size > 0
if ep_size == 1:
return (global_num_experts, None)
local_num_experts = global_num_experts // ep_size
# Create a tensor of size num_experts filled with -1
expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
# Create an expert map for the local experts
if ep_rank < (ep_size - 1):
# Each non-last rank gets local_num_experts experts.
expert_map[ep_rank * local_num_experts:
(ep_rank + 1) * local_num_experts] = \
torch.arange(0, local_num_experts, dtype=torch.int32)
else:
# All remaining experts are assigned to the last rank.
local_num_experts = (global_num_experts - ep_rank * local_num_experts)
expert_map[-local_num_experts:] = \
torch.arange(0, local_num_experts, dtype=torch.int32)
return (local_num_experts, expert_map)
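For illustration (numbers chosen arbitrarily, using the helper above with torch already imported in this module): 8 global experts over an EP group of size 3 split as 2 / 2 / 4, with the remainder landing on the last rank.

for rank in range(3):
    n_local, expert_map = determine_expert_map(
        ep_size=3, ep_rank=rank, global_num_experts=8)
    print(rank, n_local, expert_map.tolist())
# 0 2 [0, 1, -1, -1, -1, -1, -1, -1]
# 1 2 [-1, -1, 0, 1, -1, -1, -1, -1]
# 2 4 [-1, -1, -1, -1, 0, 1, 2, 3]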
class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.
......@@ -165,8 +208,15 @@ class FusedMoE(torch.nn.Module):
self.moe_tp_size = get_moe_tensor_parallel_world_size()
self.moe_tp_rank = get_moe_tensor_parallel_rank()
assert num_experts % self.moe_ep_size == 0
self.num_local_experts = num_experts // self.moe_ep_size
# self.num_local_experts = num_experts // self.moe_ep_size
if self.moe_ep_size != 0:
self.num_local_experts, self.expert_map = determine_expert_map(
ep_size=self.moe_ep_size,
ep_rank=self.moe_ep_rank,
global_num_experts=num_experts)
else:
self.num_local_experts, self.expert_map = (num_experts, None)
assert intermediate_size % self.moe_tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
self.reduce_results = reduce_results
......
......@@ -55,6 +55,10 @@ import torch.distributed as dist
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
use_groupgemm = get_bool_env_var(
"SGLANG_GROUPGEMM", default="true"
)
logger = logging.getLogger(__name__)
......@@ -606,27 +610,40 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
use_fp8 = True
buffer = self._get_buffer()
packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
buffer.low_latency_dispatch(
hidden_states,
topk_ids,
self.num_max_dispatch_tokens_per_rank,
self.num_experts,
use_fp8=use_fp8,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=input_global_scale)
if input_global_scale is not None
else dict()
),
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
if use_groupgemm:
packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
buffer.low_latency_dispatch(
hidden_states,
topk_ids,
self.num_max_dispatch_tokens_per_rank,
self.num_experts,
use_fp8=False,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
)
else:
packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
buffer.low_latency_dispatch(
hidden_states,
topk_ids,
self.num_max_dispatch_tokens_per_rank,
self.num_experts,
use_fp8=False,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=input_global_scale)
if input_global_scale is not None
else dict()
),
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
)
)
)
return packed_recv_hidden, self.packed_recv_count, event, hook
def combine_a(
......@@ -670,24 +687,35 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
ctx = torch.cuda.stream(overlap_args.stream)
with ctx:
combined_hidden_states, event, hook = buffer.low_latency_combine(
x=hidden_states,
topk_idx=topk_ids,
topk_weights=topk_weights,
handle=self.handle,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
**(
dict(
overlap=overlap_args.overlap,
src_signals=overlap_args.signal,
src_signal_expect_value=overlap_args.threshold,
)
if overlap_args is not None
else {}
),
)
if use_groupgemm:
combined_hidden_states, event, hook = buffer.low_latency_combine(
x=hidden_states,
topk_idx=topk_ids,
topk_weights=topk_weights,
handle=self.handle,
zero_copy=False,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
else:
combined_hidden_states, event, hook = buffer.low_latency_combine(
x=hidden_states,
topk_idx=topk_ids,
topk_weights=topk_weights,
handle=self.handle,
zero_copy=False,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
**(
dict(
overlap=overlap_args.overlap,
src_signals=overlap_args.signal,
src_signal_expect_value=overlap_args.threshold,
)
if overlap_args is not None
else {}
),
)
self.packed_recv_count = self.handle = None
return combined_hidden_states, event, hook
......
......@@ -317,3 +317,55 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
# a2_scale=layer.w2_input_scale,
# use_nn_moe=use_nn_moe,
# )
#
def apply_ep(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
num_local_tokens: Optional[torch.Tensor] = None,
#config_select_bs: Optional[int] = None,
routed_scaling_factor: Optional[float] = 1.0,
shared_output: Optional[torch.Tensor] = None,
#scales: Optional[torch.Tensor] = None,
num_recv_tokens_per_expert: Optional[List] = None,
**_):
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
x,
w1,
w2,
topk_ids=topk_ids,
topk_weights=topk_weights,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=w1_scale,
w2_scale=w2_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
# num_local_tokens=num_local_tokens,
#config_select_bs=config_select_bs,
#q_scales=scales
)
\ No newline at end of file