Commit 92f82dce authored by lizhigong

Merge branch 'v0.5.4_dev_yiqa' into 'v0.5.4_dev'

Use groupgemm to complete the high-throughput mode adaptation.

See merge request OpenDAS/sglang!24
parents a34b0d3d 54abdab4
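Context for the change: the high-throughput (normal DeepEP dispatch) MoE path is routed through grouped GEMM kernels, where the tokens assigned to each expert are multiplied by that expert's weights in one batched pass. A minimal plain-torch sketch of the idea, with illustrative names and shapes rather than the kernels touched in this MR:

```python
import torch

def grouped_gemm_reference(x, weights, tokens_per_expert):
    """Reference grouped GEMM: tokens are laid out contiguously per expert,
    and each expert multiplies only its own slice of the activations."""
    outputs = []
    start = 0
    for expert_id, n_tok in enumerate(tokens_per_expert):
        end = start + n_tok
        # (n_tok, hidden) @ (hidden, inter) -> (n_tok, inter)
        outputs.append(x[start:end] @ weights[expert_id])
        start = end
    return torch.cat(outputs, dim=0)

# 3 experts, 10 tokens already grouped by expert after dispatch.
x = torch.randn(10, 16)
w = torch.randn(3, 16, 32)
out = grouped_gemm_reference(x, w, tokens_per_expert=[4, 3, 3])
assert out.shape == (10, 32)
```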
@@ -563,7 +563,7 @@ class DeepEPMoE(EPMoE):
)
def forward_deepgemm_w4a8_marlin_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = (
@@ -576,7 +576,6 @@ class DeepEPMoE(EPMoE):
if all_tokens <= 0:
return hidden_states.bfloat16()
num_local_tokens = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int32, device="cuda")
expert_output = self.quant_method.apply_ep(
x=hidden_states,
w1=self.w13_weight,
@@ -591,10 +590,9 @@ class DeepEPMoE(EPMoE):
w1_scale=self.w13_weight_scale,
w2_scale=self.w2_weight_scale,
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor,
# num_local_tokens=num_local_tokens,
)
return expert_output
def forward_deepgemm_contiguous(
self,
@@ -807,11 +805,11 @@ class DeepEPMoE(EPMoE):
masked_m,
expected_m,
)
q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
# ---- second GEMM ----
n2 = w2_scales.size(1)
down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
m_grouped_w4a8_gemm_nt_masked(
......
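In the hunk above, fuse_silu_mul_quant_ep produces the quantized activations (q_a2_all with scales q_a2_scale) from the gate/up output, and m_grouped_w4a8_gemm_nt_masked runs the second GEMM per expert group, where masked_m[g] gives the number of valid rows in group g. A plain-torch reference of that masked-group contract (shapes illustrative; this is not the w4a8 kernel's actual signature):

```python
import torch

def masked_grouped_gemm_reference(a, b, masked_m):
    """For each expert group g, multiply only the first masked_m[g] rows of
    a[g]; the remaining (padding) rows of the output stay zero."""
    num_groups, m, _ = a.shape
    n = b.shape[2]
    out = torch.zeros(num_groups, m, n, dtype=torch.bfloat16)
    for g in range(num_groups):
        rows = int(masked_m[g])
        out[g, :rows] = (a[g, :rows] @ b[g]).to(torch.bfloat16)
    return out

a = torch.randn(2, 8, 16)          # (num_groups, m, k) activations
b = torch.randn(2, 16, 4)          # (num_groups, k, n) expert weights
masked_m = torch.tensor([5, 3])    # valid rows per group
out = masked_grouped_gemm_reference(a, b, masked_m)
assert out.shape == (2, 8, 4) and out[0, 5:].abs().sum() == 0
```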
@@ -4,7 +4,7 @@ import logging
from contextlib import nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.dp_attention import get_is_extend_in_batch
@@ -357,18 +357,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
):
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
topk_ids = topk_ids.to(torch.int64)
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and not get_moe_runner_backend().is_cutlass()
):
# TODO hard code 128 block quant,use fp8 communication
hidden_states = sglang_per_token_group_quant_fp8(
hidden_states,
128,
column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
# if (
# deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
# and not get_moe_runner_backend().is_cutlass()
# ):
# # TODO hard code 128 block quant,use fp8 communication
# hidden_states = sglang_per_token_group_quant_fp8(
# hidden_states,
# 128,
# column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
# scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
# scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
# )
previous_event = Buffer.capture() if self.async_finish else None
return hidden_states, topk_ids, topk_weights, previous_event
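The block commented out above previously ran per-token-group FP8 quantization (group size 128) on the hidden states before dispatch; with this change the normal path hands the unquantized tensor to the dispatcher. A rough torch sketch of what that per-group quantization computes (reference only, not the sglang_per_token_group_quant_fp8 kernel; the e4m3 max of 448 and the scale layout are assumptions of the sketch):

```python
import torch

def per_token_group_quant_fp8_reference(x, group_size=128):
    """Each token's features are split into contiguous groups of `group_size`,
    and every group gets one scale so its max maps to the fp8 e4m3 range."""
    num_tokens, hidden = x.shape
    assert hidden % group_size == 0
    groups = x.view(num_tokens, hidden // group_size, group_size)
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
    x_q = (groups / scales).to(torch.float8_e4m3fn)
    return x_q.view(num_tokens, hidden), scales.squeeze(-1)

x = torch.randn(4, 256)
x_q, scales = per_token_group_quant_fp8_reference(x)
print(x_q.shape, scales.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
```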
@@ -380,7 +380,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
num_recv_tokens_per_expert,
event,
) = self._dispatch_core(hidden_states, topk_ids, topk_weights, previous_event)
event.current_stream_wait() if self.async_finish else ()
# event.current_stream_wait() if self.async_finish else ()
if isinstance(hidden_states, tuple):
hidden_states, hidden_states_scale = hidden_states
@@ -435,18 +435,23 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=num_tokens_per_expert,
previous_event=previous_event,
async_finish=self.async_finish,
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
previous_event=None,
async_finish=False,
allocate_on_comm_stream=False,
expert_alignment=1,
config=DeepEPConfig.get_instance().normal_dispatch_config,
)
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
num_recv_tokens_per_expert,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
num_tokens_per_expert=num_tokens_per_expert,
)
# get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
# num_recv_tokens_per_expert,
# num_tokens_per_rank=num_tokens_per_rank,
# num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
# num_tokens_per_expert=num_tokens_per_expert,
# )
self.rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
recv_topk_ids = torch.where(
recv_topk_ids == -1,
self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
recv_topk_ids + self.rank_expert_offset)
return (
recv_x,
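The added rank_expert_offset logic converts the local expert ids returned by dispatch back into global ids, and rewrites the -1 sentinel entries to an expert id that belongs to a different rank so they fall outside this rank's expert range. A standalone illustration with made-up sizes (not the dispatcher's real tensors):

```python
import torch

# Illustration only: 8 experts split across 2 EP ranks (4 local experts each).
num_experts, ep_rank, ep_world_size = 8, 1, 2
rank_expert_offset = ep_rank * (num_experts // ep_world_size)  # 4 on rank 1

recv_topk_ids = torch.tensor([[0, 2, -1], [3, -1, 1]])  # local ids, -1 = sentinel
global_ids = torch.where(
    recv_topk_ids == -1,
    # the sentinel is redirected to an expert this rank does not own
    torch.tensor(num_experts - 1 if rank_expert_offset == 0 else 0),
    recv_topk_ids + rank_expert_offset,
)
print(global_ids)  # tensor([[4, 6, 0], [7, 0, 5]])
```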
@@ -495,7 +500,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
def combine_b(self, output, previous_event):
hidden_states, event = self._combine_core(output, previous_event)
event.current_stream_wait() if self.async_finish else ()
# event.current_stream_wait() if self.async_finish else ()
self.handle = None
self.src2dst = None
return hidden_states
@@ -505,9 +510,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
combined_x, _, event = buffer.combine(
x,
self.handle,
async_finish=self.async_finish,
previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None,
async_finish=False,
previous_event=None,
allocate_on_comm_stream=False,
config=DeepEPConfig.get_instance().normal_combine_config,
)
return combined_x, event
@@ -536,7 +541,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
"""
self.return_recv_hook = return_recv_hook
self.return_recv_hook = False
self.device_module = torch.get_device_module()
self.quant_config = {}
@@ -693,7 +698,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
topk_idx=topk_ids,
topk_weights=topk_weights,
handle=self.handle,
zero_copy=False,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
)
@@ -703,7 +708,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
topk_idx=topk_ids,
topk_weights=topk_weights,
handle=self.handle,
zero_copy=False,
async_finish=not self.return_recv_hook,
return_recv_hook=self.return_recv_hook,
**(
......
@@ -42,7 +42,7 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
packed_modules_mapping: Optional[dict[str, list[str]]] = None,
):
super().__init__(
target_scheme_map,
@@ -52,10 +52,10 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
sparsity_ignore_list,
kv_cache_scheme,
config,
packed_modules_mapping,
)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[str]:
@@ -73,7 +73,7 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE # Avoid circular import
# from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.radix_attention import RadixAttention
# Check if the layer is skipped for quantization.
if should_ignore_layer(prefix,
ignore=self.ignore,
@@ -85,8 +85,8 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
return UnquantizedEmbeddingMethod()#UnquantizedLinearMethod()
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
# if isinstance(layer, RadixAttention):
# return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, RadixAttention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)
return None