Commit 92f82dce authored by lizhigong

Merge branch 'v0.5.4_dev_yiqa' into 'v0.5.4_dev'

Use groupgemm to adapt the high-throughput mode.

See merge request OpenDAS/sglang!24
parents a34b0d3d 54abdab4
@@ -563,7 +563,7 @@ class DeepEPMoE(EPMoE):
        )

    def forward_deepgemm_w4a8_marlin_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = (
@@ -576,7 +576,6 @@ class DeepEPMoE(EPMoE):
        if all_tokens <= 0:
            return hidden_states.bfloat16()
-        num_local_tokens = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int32, device="cuda")
        expert_output = self.quant_method.apply_ep(
            x=hidden_states,
            w1=self.w13_weight,
@@ -591,10 +590,9 @@ class DeepEPMoE(EPMoE):
            w1_scale=self.w13_weight_scale,
            w2_scale=self.w2_weight_scale,
            routed_scaling_factor = self.moe_runner_config.routed_scaling_factor,
-            # num_local_tokens=num_local_tokens,
        )
        return expert_output

    def forward_deepgemm_contiguous(
        self,
@@ -807,11 +805,11 @@ class DeepEPMoE(EPMoE):
            masked_m,
            expected_m,
        )
        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)

        # ---- second GEMM ----
        n2 = w2_scales.size(1)
        down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
        m_grouped_w4a8_gemm_nt_masked(
......
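Note on the group-GEMM path above: the second projection is a masked grouped GEMM, i.e. one GEMM per expert group where only the first masked_m[g] rows of group g carry real tokens and the remainder is padding from the contiguous layout. The sketch below is a plain-PyTorch reference of that contract only; it is not the m_grouped_w4a8_gemm_nt_masked kernel, it ignores the w4a8 quantization, and the tensor layout (NT, i.e. x[g] @ w[g].T) is an assumption for illustration.

import torch

def masked_grouped_gemm_reference(x, w, masked_m):
    # x: [num_groups, m, k] per-expert activation slabs; only the first
    #    masked_m[g] rows of group g are real tokens, the rest is padding.
    # w: [num_groups, n, k] one weight matrix per expert group (assumed NT
    #    layout, so each group computes x[g] @ w[g].T).
    # masked_m: [num_groups] number of valid rows per expert group.
    num_groups, m, _ = x.shape
    n = w.size(1)
    out = torch.zeros((num_groups, m, n), device=x.device, dtype=torch.bfloat16)
    for g in range(num_groups):
        rows = int(masked_m[g])
        if rows == 0:
            continue  # this expert received no tokens on this rank
        out[g, :rows] = (x[g, :rows].float() @ w[g].float().t()).to(torch.bfloat16)
    return out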
@@ -4,7 +4,7 @@ import logging
from contextlib import nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union

+from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.dp_attention import get_is_extend_in_batch
@@ -357,18 +357,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
    ):
        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
        topk_ids = topk_ids.to(torch.int64)
-        if (
-            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
-            and not get_moe_runner_backend().is_cutlass()
-        ):
-            # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(
-                hidden_states,
-                128,
-                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-            )
+        # if (
+        #     deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+        #     and not get_moe_runner_backend().is_cutlass()
+        # ):
+        #     # TODO hard code 128 block quant,use fp8 communication
+        #     hidden_states = sglang_per_token_group_quant_fp8(
+        #         hidden_states,
+        #         128,
+        #         column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #         scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        #     )
        previous_event = Buffer.capture() if self.async_finish else None
        return hidden_states, topk_ids, topk_weights, previous_event
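For context on the block commented out above: it quantized the activations to FP8 before dispatch, in groups of 128 channels per token, with the scale layout controlled by the DEEPGEMM_SCALE_UE8M0 flags. The sketch below is a simplified reference of that per-token-group quantization, assuming plain e4m3 scales; it is not the sglang_per_token_group_quant_fp8 kernel and ignores the column-major / TMA-alignment / UE8M0 options.

import torch

def per_token_group_quant_fp8_reference(x: torch.Tensor, group_size: int = 128):
    # Quantize each contiguous group of `group_size` channels of every token
    # to FP8 (e4m3) and return one scale per (token, group).
    tokens, hidden = x.shape
    assert hidden % group_size == 0
    groups = x.float().view(tokens, hidden // group_size, group_size)
    # e4m3 represents magnitudes up to 448, so scale each group into that range.
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / 448.0
    x_q = (groups / scale).to(torch.float8_e4m3fn).view(tokens, hidden)
    return x_q, scale.squeeze(-1)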
@@ -380,7 +380,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            num_recv_tokens_per_expert,
            event,
        ) = self._dispatch_core(hidden_states, topk_ids, topk_weights, previous_event)
-        event.current_stream_wait() if self.async_finish else ()
+        # event.current_stream_wait() if self.async_finish else ()
        if isinstance(hidden_states, tuple):
            hidden_states, hidden_states_scale = hidden_states
@@ -435,18 +435,23 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=num_tokens_per_expert,
-            previous_event=previous_event,
-            async_finish=self.async_finish,
-            allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
-            expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
+            previous_event=None,
+            async_finish=False,
+            allocate_on_comm_stream=False,
+            expert_alignment=1,
            config=DeepEPConfig.get_instance().normal_dispatch_config,
        )
-        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-            num_recv_tokens_per_expert,
-            num_tokens_per_rank=num_tokens_per_rank,
-            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-            num_tokens_per_expert=num_tokens_per_expert,
-        )
+        # get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+        #     num_recv_tokens_per_expert,
+        #     num_tokens_per_rank=num_tokens_per_rank,
+        #     num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+        #     num_tokens_per_expert=num_tokens_per_expert,
+        # )
+        self.rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
+        recv_topk_ids = torch.where(
+            recv_topk_ids == -1,
+            self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
+            recv_topk_ids + self.rank_expert_offset)
        return (
            recv_x,
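The five added lines above shift each received local expert id into this EP rank's global id range and re-point the -1 padding slots at an expert the rank does not own. A small worked example (the sizes are made up for illustration):

import torch

# Illustrative sizes: 256 experts, EP world size 8, this rank = 3.
num_experts, ep_size, ep_rank = 256, 8, 3
rank_expert_offset = ep_rank * (num_experts // ep_size)  # 3 * 32 = 96

recv_topk_ids = torch.tensor([[5, 12, -1], [0, -1, 31]])  # -1 marks padding slots
remapped = torch.where(
    recv_topk_ids == -1,
    # padding goes to an expert this rank does not own (expert 0 here,
    # or expert num_experts - 1 when the rank owns expert 0)
    num_experts - 1 if rank_expert_offset == 0 else 0,
    recv_topk_ids + rank_expert_offset,
)
print(remapped)  # tensor([[101, 108, 0], [96, 0, 127]])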
@@ -495,7 +500,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
    def combine_b(self, output, previous_event):
        hidden_states, event = self._combine_core(output, previous_event)
-        event.current_stream_wait() if self.async_finish else ()
+        # event.current_stream_wait() if self.async_finish else ()
        self.handle = None
        self.src2dst = None
        return hidden_states
@@ -505,9 +510,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
        combined_x, _, event = buffer.combine(
            x,
            self.handle,
-            async_finish=self.async_finish,
-            previous_event=previous_event,
-            allocate_on_comm_stream=previous_event is not None,
+            async_finish=False,
+            previous_event=None,
+            allocate_on_comm_stream=False,
            config=DeepEPConfig.get_instance().normal_combine_config,
        )
        return combined_x, event
@@ -536,7 +541,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
        https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
        """
-        self.return_recv_hook = return_recv_hook
+        self.return_recv_hook = False
        self.device_module = torch.get_device_module()
        self.quant_config = {}
@@ -693,7 +698,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
            topk_idx=topk_ids,
            topk_weights=topk_weights,
            handle=self.handle,
            zero_copy=False,
            async_finish=not self.return_recv_hook,
            return_recv_hook=self.return_recv_hook,
        )
@@ -703,7 +708,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
            topk_idx=topk_ids,
            topk_weights=topk_weights,
            handle=self.handle,
            zero_copy=False,
            async_finish=not self.return_recv_hook,
            return_recv_hook=self.return_recv_hook,
            **(
......
@@ -42,7 +42,7 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
        sparsity_ignore_list: list[str],
        kv_cache_scheme: Optional[dict[str, Any]] = None,
        config: Optional[dict[str, Any]] = None,
        packed_modules_mapping: Optional[dict[str, list[str]]] = None,
    ):
        super().__init__(
            target_scheme_map,
@@ -52,10 +52,10 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
            sparsity_ignore_list,
            kv_cache_scheme,
            config,
            packed_modules_mapping,
        )

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[str]:
@@ -73,7 +73,7 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE  # Avoid circular import
-        # from sglang.srt.layers.radix_attention import RadixAttention
+        from sglang.srt.layers.radix_attention import RadixAttention

        # Check if the layer is skipped for quantization.
        if should_ignore_layer(prefix,
                               ignore=self.ignore,
@@ -85,8 +85,8 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
                return UnquantizedEmbeddingMethod()  # UnquantizedLinearMethod()
            layer.scheme = scheme
            return CompressedTensorsLinearMethod(self)
-        # if isinstance(layer, RadixAttention):
-        #     return CompressedTensorsKVCacheMethod(self)
+        if isinstance(layer, RadixAttention):
+            return CompressedTensorsKVCacheMethod(self)
        if isinstance(layer, FusedMoE):
            return CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)
        return None
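Taken together with the re-enabled import, get_quant_method in this config now routes layers in the order summarized below (a paraphrase of the diff, not the verbatim method body; the linear-layer scheme resolution is unchanged):

# Routing order in get_quant_method after this change (paraphrase):
#   layer matched by the ignore list    -> UnquantizedEmbeddingMethod()
#   linear layer with a resolved scheme -> CompressedTensorsLinearMethod(self)
#   isinstance(layer, RadixAttention)   -> CompressedTensorsKVCacheMethod(self)   # re-enabled here
#   isinstance(layer, FusedMoE)         -> CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)
#   anything else                       -> None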