Commit cbf7a3e3 authored by maxiao1

Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'

Adapt the low-latency mode using groupgemm

See merge request OpenDAS/sglang!19
parents 3246cea1 e7a1f94c
@@ -19,8 +19,8 @@ logger = logging.getLogger(__name__)
 use_vllm_custom_allreduce = get_bool_env_var(
     "USE_VLLM_CUSTOM_ALLREDUCE", default="false"
 )
-use_dcu_custom_allreduce= get_bool_env_var(
-    "USE_DCU_CUSTOM_ALLREDUCE", default="false"
+use_dcu_custom_allreduce = get_bool_env_var(
+    "USE_DCU_CUSTOM_ALLREDUCE", default="true"
 )
 
 if not is_hpu():
...
@@ -2,7 +2,6 @@ import logging
 import torch
 import triton
 from sglang.srt.utils import ceil_div, is_cuda
-
 logger = logging.getLogger(__name__)
...
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
 import torch
+import torch.distributed as dist
 from sglang.srt import single_batch_overlap
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe import (
@@ -39,6 +39,8 @@ if TYPE_CHECKING:
         DeepEPNormalOutput,
         DispatchOutput,
     )
+from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 
 _is_hip = is_hip()
 _is_npu = is_npu()
@@ -494,16 +496,19 @@ class DeepEPMoE(EPMoE):
                 f"Dispatch output is not supported"
             )
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
-            if (
-                get_moe_runner_backend().is_flashinfer_cutedsl()
-                and self.quant_config.get_name() == "modelopt_fp4"
-            ):
-                return self.forward_flashinfer_cutedsl(
-                    dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
-                )
-            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
-            assert down_gemm_overlap_args is None
-            return self.forward_deepgemm_masked(dispatch_output)
+            if self.use_w4a8_marlin:
+                return self.forward_groupgemm_w4a8_marlin_masked(dispatch_output)
+            else:
+                if (
+                    get_moe_runner_backend().is_flashinfer_cutedsl()
+                    and self.quant_config.get_name() == "modelopt_fp4"
+                ):
+                    return self.forward_flashinfer_cutedsl(
+                        dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
+                    )
+                assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+                assert down_gemm_overlap_args is None
+                return self.forward_deepgemm_masked(dispatch_output)
         else:
             raise ValueError(
                 f"Dispatch output format {dispatch_output.format} is not supported"
@@ -561,29 +566,35 @@ class DeepEPMoE(EPMoE):
         self,
         dispatch_output: DeepEPNormalOutput,
     ):
-        hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+        hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = (
             dispatch_output
         )
-        # hidden_states_int8, hidden_states_scale = hidden_states_int8
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
-        # if num_recv_tokens_per_expert is None:
-        return hidden_states_int8.bfloat16()
-        # expert_output = self.quant_method.apply_ep(
-        #     layer=self,
-        #     x=dispatch_output,
-        #     topk_weights=topk_weights,
-        #     topk_ids=topk_idx,
-        #     global_num_experts=self.global_num_experts,
-        #     expert_map=self.expert_map,
-        #     activation=self.activation,
-        #     apply_router_weight_on_input=self.apply_router_weight_on_input,
-        #     use_nn_moe=self.use_nn_moe,
-        #     num_local_tokens=dispatch_recv_num_token,
-        #     config_select_bs=hidden_states.shape[0],
-        #     scales=dispatch_scales if self.use_int8_dispatch else None
-        #     # routed_scaling_factor=self.routed_scaling_factor,
-        # )
-        # return expert_output
+        all_tokens = sum(num_recv_tokens_per_expert)
+        if all_tokens <= 0:
+            return hidden_states.bfloat16()
+        num_local_tokens = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int32, device="cuda")
+        expert_output = self.quant_method.apply_ep(
+            x=hidden_states,
+            w1=self.w13_weight,
+            w2=self.w2_weight,
+            topk_ids=topk_idx,
+            topk_weights=topk_weights,
+            global_num_experts=self.moe_runner_config.num_experts,
+            expert_map=self.expert_map,
+            activation=self.moe_runner_config.activation,
+            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
+            use_nn_moe=False,
+            w1_scale=self.w13_weight_scale,
+            w2_scale=self.w2_weight_scale,
+            routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
+            # num_local_tokens=num_local_tokens,
+        )
+        return expert_output
 
     def forward_deepgemm_contiguous(
         self,
@@ -763,6 +774,56 @@ class DeepEPMoE(EPMoE):
             dispatch_output=dispatch_output,
         )
 
+    def forward_groupgemm_w4a8_marlin_masked(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        hidden_states, _, _, _, masked_m, expected_m = dispatch_output
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+
+        # base shapes
+        num_groups, m, k = hidden_states.size()
+        expected_m = m // 2  # shape required by the grouped-GEMM operator
+
+        # ---- first quant: per-token int8 quantization of the bf16 dispatch output ----
+        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states)
+
+        # ---- weights & scales ----
+        w13_weight = self.w13_weight
+        w13_scales = self.w13_weight_scale
+        w2_weight = self.w2_weight
+        w2_scales = self.w2_weight_scale
+
+        n1 = w13_scales.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16
+        )
+
+        # ---- first GEMM ----
+        m_grouped_w4a8_gemm_nt_masked(
+            (q_a1_all, q_a1_scale),
+            (w13_weight, w13_scales),
+            gateup_output,
+            masked_m,
+            expected_m,
+        )
+
+        # fused SiLU-and-mul plus re-quantization between the two GEMMs
+        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
+
+        # ---- second GEMM ----
+        n2 = w2_scales.size(1)
+        down_output = torch.empty(
+            (num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16
+        )
+        m_grouped_w4a8_gemm_nt_masked(
+            (q_a2_all, q_a2_scale),
+            (w2_weight, w2_scales),
+            down_output,
+            masked_m,
+            expected_m,
+        )
+        return down_output
+
     def forward_deepgemm_masked(
         self,
         dispatch_output: DeepEPLLOutput,
...
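
Reviewer note on the new masked path: in DeepEPLLOutput, hidden_states has shape (num_groups, m, k), i.e. each local expert slot carries up to m token rows, and only the first masked_m[g] rows of slot g are valid; the code above then overrides expected_m with m // 2 because of the operator's shape requirement. The snippet below is a plain-PyTorch emulation of the masked grouped-GEMM semantics the two kernel calls rely on, with the int4/int8 packing and the per-token/per-channel scales folded away; the name masked_grouped_gemm_reference and the bf16 math are illustrative only and are not part of this MR.

import torch

def masked_grouped_gemm_reference(x, w, masked_m):
    # x: [num_groups, m, k] activations, w: [num_groups, n, k] weights ("nt" layout),
    # masked_m: [num_groups] number of valid rows per expert group.
    num_groups, m, _ = x.shape
    n = w.shape[1]
    out = torch.zeros(num_groups, m, n, dtype=torch.bfloat16, device=x.device)
    for g in range(num_groups):
        rows = int(masked_m[g])  # rows beyond masked_m[g] are padding and stay zero
        if rows > 0:
            out[g, :rows] = (x[g, :rows].float() @ w[g].float().T).to(torch.bfloat16)
    return out
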
@@ -98,6 +98,49 @@ class FusedMoeWeightScaleSupported(Enum):
     GROUP = "group"
     BLOCK = "block"
 
+def determine_expert_map(
+        ep_size: int, ep_rank: int,
+        global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]:
+    """
+    Calculates how many experts should be assigned to each rank for EP and
+    creates a mapping from global to local expert index. Experts are
+    distributed evenly across ranks. Any remaining are assigned to the
+    last rank.
+
+    Args:
+        ep_size (int): The size of the expert parallel group.
+        ep_rank (int): The rank of the current process within the expert parallel group.
+        global_num_experts (int): The total number of experts in the model.
+
+    Returns:
+        tuple[int, Optional[torch.Tensor]]: A tuple containing:
+            - local_num_experts (int): The number of experts assigned
+              to the current rank.
+            - expert_map (Optional[torch.Tensor]): A tensor of shape
+              (global_num_experts,) mapping from global to local index.
+              Contains -1 for experts not assigned to the current rank.
+              Returns None if ep_size is 1.
+    """
+    assert ep_size > 0
+    if ep_size == 1:
+        return (global_num_experts, None)
+
+    local_num_experts = global_num_experts // ep_size
+
+    # Create a tensor of size num_experts filled with -1
+    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
+    # Create an expert map for the local experts
+    if ep_rank < (ep_size - 1):
+        # Each non-last rank gets local_num_experts experts.
+        expert_map[ep_rank * local_num_experts:
+                   (ep_rank + 1) * local_num_experts] = \
+            torch.arange(0, local_num_experts, dtype=torch.int32)
+    else:
+        # All remaining experts are assigned to the last rank.
+        local_num_experts = (global_num_experts - ep_rank * local_num_experts)
+        expert_map[-local_num_experts:] = \
+            torch.arange(0, local_num_experts, dtype=torch.int32)
+    return (local_num_experts, expert_map)
+
+
 class FusedMoE(torch.nn.Module):
     """FusedMoE layer for MoE models.
@@ -165,8 +208,15 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_size = get_moe_tensor_parallel_world_size()
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
-        self.num_local_experts = num_experts // self.moe_ep_size
+        # self.num_local_experts = num_experts // self.moe_ep_size
+        if self.moe_ep_size != 0:
+            self.num_local_experts, self.expert_map = determine_expert_map(
+                ep_size=self.moe_ep_size,
+                ep_rank=self.moe_ep_rank,
+                global_num_experts=num_experts)
+        else:
+            self.local_num_experts, self.expert_map = (self.global_num_experts,
+                                                       None)
         assert intermediate_size % self.moe_tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results
...
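
Reviewer note: a small worked example of the expert placement produced by determine_expert_map above, assuming the function from this diff is in scope; the values follow directly from its logic. Note that FusedMoE still asserts num_experts % self.moe_ep_size == 0, so the last-rank remainder branch only comes into play if that assert is ever relaxed.

# 10 global experts over 4 EP ranks: 10 // 4 = 2 per rank, remainder on the last rank.
# rank 0: local_num_experts=2, expert_map=[ 0,  1, -1, -1, -1, -1, -1, -1, -1, -1]
# rank 1: local_num_experts=2, expert_map=[-1, -1,  0,  1, -1, -1, -1, -1, -1, -1]
# rank 2: local_num_experts=2, expert_map=[-1, -1, -1, -1,  0,  1, -1, -1, -1, -1]
# rank 3: local_num_experts=4, expert_map=[-1, -1, -1, -1, -1, -1,  0,  1,  2,  3]
for rank in range(4):
    local_num_experts, expert_map = determine_expert_map(
        ep_size=4, ep_rank=rank, global_num_experts=10
    )
    print(rank, local_num_experts, expert_map.tolist())
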
@@ -55,6 +55,10 @@ import torch.distributed as dist
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 
+use_groupgemm = get_bool_env_var(
+    "SGLANG_GROUPGEMM", default="true"
+)
+
 logger = logging.getLogger(__name__)
@@ -606,27 +610,40 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         use_fp8 = True
         buffer = self._get_buffer()
-        packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
-            buffer.low_latency_dispatch(
-                hidden_states,
-                topk_ids,
-                self.num_max_dispatch_tokens_per_rank,
-                self.num_experts,
-                use_fp8=use_fp8,
-                **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
-                **(
-                    dict(x_global_scale=input_global_scale)
-                    if input_global_scale is not None
-                    else dict()
-                ),
-                async_finish=not self.return_recv_hook,
-                return_recv_hook=self.return_recv_hook,
-                round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-                and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
-                use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-                and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
-            )
-        )
+        if use_groupgemm:
+            packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
+                buffer.low_latency_dispatch(
+                    hidden_states,
+                    topk_ids,
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.num_experts,
+                    use_fp8=False,
+                    async_finish=not self.return_recv_hook,
+                    return_recv_hook=self.return_recv_hook,
+                )
+            )
+        else:
+            packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
+                buffer.low_latency_dispatch(
+                    hidden_states,
+                    topk_ids,
+                    self.num_max_dispatch_tokens_per_rank,
+                    self.num_experts,
+                    use_fp8=False,
+                    **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
+                    **(
+                        dict(x_global_scale=input_global_scale)
+                        if input_global_scale is not None
+                        else dict()
+                    ),
+                    async_finish=not self.return_recv_hook,
+                    return_recv_hook=self.return_recv_hook,
+                    round_scale=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                    and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
+                    use_ue8m0=deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                    and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
+                )
+            )
         return packed_recv_hidden, self.packed_recv_count, event, hook
 
     def combine_a(
@@ -670,24 +687,35 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             ctx = torch.cuda.stream(overlap_args.stream)
         with ctx:
-            combined_hidden_states, event, hook = buffer.low_latency_combine(
-                x=hidden_states,
-                topk_idx=topk_ids,
-                topk_weights=topk_weights,
-                handle=self.handle,
-                async_finish=not self.return_recv_hook,
-                return_recv_hook=self.return_recv_hook,
-                **(
-                    dict(
-                        overlap=overlap_args.overlap,
-                        src_signals=overlap_args.signal,
-                        src_signal_expect_value=overlap_args.threshold,
-                    )
-                    if overlap_args is not None
-                    else {}
-                ),
-            )
+            if use_groupgemm:
+                combined_hidden_states, event, hook = buffer.low_latency_combine(
+                    x=hidden_states,
+                    topk_idx=topk_ids,
+                    topk_weights=topk_weights,
+                    handle=self.handle,
+                    zero_copy=False,
+                    async_finish=not self.return_recv_hook,
+                    return_recv_hook=self.return_recv_hook,
+                )
+            else:
+                combined_hidden_states, event, hook = buffer.low_latency_combine(
+                    x=hidden_states,
+                    topk_idx=topk_ids,
+                    topk_weights=topk_weights,
+                    handle=self.handle,
+                    zero_copy=False,
+                    async_finish=not self.return_recv_hook,
+                    return_recv_hook=self.return_recv_hook,
+                    **(
+                        dict(
+                            overlap=overlap_args.overlap,
+                            src_signals=overlap_args.signal,
+                            src_signal_expect_value=overlap_args.threshold,
+                        )
+                        if overlap_args is not None
+                        else {}
+                    ),
+                )
         self.packed_recv_count = self.handle = None
         return combined_hidden_states, event, hook
...
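
Reviewer note: with SGLANG_GROUPGEMM enabled (the new default), low-latency dispatch is issued with use_fp8=False, so tokens travel in bf16 and are quantized to int8 only on the receiving side by per_token_quant_int8, and combine passes zero_copy=False; the FP8/NVFP4 arguments are used only on the fallback branch. Below is a minimal sketch of toggling the flag for an A/B comparison; the _flag helper is illustrative only and merely assumes a get_bool_env_var-style convention where truthy strings such as "1" or "true" enable a flag.

import os

def _flag(name: str, default: str = "true") -> bool:
    # Assumed boolean-env-var parsing: "1"/"true"/"yes" (any case) -> True.
    return os.environ.get(name, default).strip().lower() in ("1", "true", "yes")

# Fall back to the original FP8/NVFP4 low-latency path for comparison runs:
os.environ["SGLANG_GROUPGEMM"] = "false"
assert _flag("SGLANG_GROUPGEMM") is False
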
@@ -317,3 +317,55 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
         #     a2_scale=layer.w2_input_scale,
         #     use_nn_moe=use_nn_moe,
         # )
+        #
+
+    def apply_ep(self,
+                 x: torch.Tensor,
+                 w1: torch.Tensor,
+                 w2: torch.Tensor,
+                 topk_ids: torch.Tensor,
+                 topk_weights: torch.Tensor,
+                 global_num_experts: int = -1,
+                 expert_map: Optional[torch.Tensor] = None,
+                 apply_router_weight_on_input: bool = False,
+                 activation: str = "silu",
+                 w1_scale: Optional[torch.Tensor] = None,
+                 w2_scale: Optional[torch.Tensor] = None,
+                 a1_scale: Optional[torch.Tensor] = None,
+                 a2_scale: Optional[torch.Tensor] = None,
+                 use_nn_moe: Optional[bool] = False,
+                 num_local_tokens: Optional[torch.Tensor] = None,
+                 # config_select_bs: Optional[int] = None,
+                 routed_scaling_factor: Optional[float] = 1.0,
+                 shared_output: Optional[torch.Tensor] = None,
+                 # scales: Optional[torch.Tensor] = None,
+                 num_recv_tokens_per_expert: List = None,
+                 **_):
+        workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
+        return fused_experts_impl_w4a8_marlin(
+            x,
+            w1,
+            w2,
+            topk_ids=topk_ids,
+            topk_weights=topk_weights,
+            workspace=workspace,
+            global_reduce_buffer=global_reduce_buffer,
+            inplace=True,
+            use_int4_w4a8=True,
+            per_channel_quant=True,
+            activation=activation,
+            expert_map=expert_map,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            global_num_experts=global_num_experts,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            use_nn_moe=use_nn_moe,
+            shared_output=shared_output,
+            routed_scaling_factor=routed_scaling_factor,
+            # num_local_tokens=num_local_tokens,
+            # config_select_bs=config_select_bs,
+            # q_scales=scales,
+        )
\ No newline at end of file