Commit e3c76844 authored by lizhigong

Merge branch 'v0.5.4_dev_yiqa' into 'v0.5.4_dev'

V0.5.4 dev yiqa

See merge request OpenDAS/sglang!40
parents 078de197 0a71d6b1
@@ -3,6 +3,7 @@ from __future__ import annotations
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 from collections import defaultdict
+from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
 from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
 import torch
@@ -119,8 +120,8 @@ def fuse_silu_mul_quant_ep_wrapper(
         topk,
         expect_m
     )

 def fuse_silu_mul_quant_ep_fake(
     input: torch.Tensor,
     tokens_per_expert: Optional[torch.Tensor] = None,
@@ -695,6 +696,11 @@ class DeepEPMoE(EPMoE):
         if all_tokens <= 0:
             return hidden_states.bfloat16()
+        rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
+        topk_idx = torch.where(
+            topk_idx == -1,
+            self.num_experts - 1 if rank_expert_offset == 0 else 0,
+            topk_idx + rank_expert_offset)
         expert_output = self.quant_method.apply_ep(
             x=hidden_states,
             w1=self.w13_weight,
@@ -708,6 +714,7 @@ class DeepEPMoE(EPMoE):
             use_nn_moe=False,
             w1_scale=self.w13_weight_scale,
             w2_scale=self.w2_weight_scale,
+            a1_scale=hidden_states_scale,
             routed_scaling_factor = self.moe_runner_config.routed_scaling_factor,
         )
         return expert_output
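
Note on the `topk_idx` remap added above: after dispatch, slots not routed to a local expert carry `-1`, and the new code shifts the remaining local expert ids to global ids while, presumably, parking the padded slots on an expert this rank does not own. A minimal sketch of that behaviour with assumed toy sizes (8 experts, EP world size 4, rank 1):

```python
import torch

num_experts, ep_size, ep_rank = 8, 4, 1
rank_expert_offset = ep_rank * (num_experts // ep_size)  # rank 1 owns global experts 2..3

topk_idx = torch.tensor([[0, 1, -1], [1, -1, -1]])  # local ids; -1 marks padded slots
remapped = torch.where(
    topk_idx == -1,
    num_experts - 1 if rank_expert_offset == 0 else 0,  # padding goes to an expert owned elsewhere
    topk_idx + rank_expert_offset,                      # real local ids become global ids
)
print(remapped.tolist())  # [[2, 3, 0], [3, 0, 0]]
```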
@@ -740,10 +747,9 @@ class DeepEPMoE(EPMoE):
                     active_experts.add(e)
                 token_expert_pos[t] = lst

-        active_experts = sorted(list(active_experts))
-        num_active = len(active_experts)
-        if num_active == 0:
+        if not active_experts:
             return hidden_states.bfloat16()
+        active_experts = sorted(list(active_experts))

         counts = defaultdict(int)
         for t in range(M):
@@ -752,12 +758,9 @@ class DeepEPMoE(EPMoE):
         per_expert_block = {}
         for e in active_experts:
-            cnt = counts.get(e, 0)
-            if cnt <= 0:
-                per_expert_block[e] = 0
-            else:
-                needed = ((cnt + 256 - 1) // 256) * 256  # next multiple of 256
-                per_expert_block[e] = max(256, needed)
+            cnt = counts[e]
+            needed = ((cnt + 255) // 256) * 256  # same as ceil(cnt/256)*256
+            per_expert_block[e] = max(256, needed)

         expert_slot_offset = {}
         offset = 0
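
The simplified rounding above keeps the old behaviour for active experts: each count is rounded up to the next multiple of 256, with a floor of one full block. A minimal sketch of the formula (the helper name is illustrative, not from the diff):

```python
def block_rows(cnt: int, align: int = 256) -> int:
    # round cnt up to the next multiple of `align`, never below one full block
    return max(align, ((cnt + align - 1) // align) * align)

assert block_rows(1) == 256
assert block_rows(256) == 256
assert block_rows(257) == 512
assert block_rows(600) == 768
```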
@@ -766,7 +769,8 @@ class DeepEPMoE(EPMoE):
             offset += per_expert_block[e]
         pad_M = offset

-        hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
+        hidden_states_packed = torch.empty((pad_M, K), device=device, dtype=hidden_states.dtype)
+        hidden_states_scale_packed = torch.empty((pad_M,), device=device, dtype=hidden_states_scale.dtype)
         m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
         slot_counters = {e: 0 for e in active_experts}
@@ -776,26 +780,27 @@ class DeepEPMoE(EPMoE):
             for (e, pos) in token_expert_pos[t]:
                 start = expert_slot_offset[e]
                 slot = slot_counters[e]
-                if slot >= per_expert_block[e]:
-                    raise RuntimeError(f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}")
                 row = start + slot
                 hidden_states_packed[row] = hidden_states[t]
-                m_indices[row] = int(e)
+                hidden_states_scale_packed[row] = hidden_states_scale[t]
+                m_indices[row] = e
                 slot_counters[e] += 1
-                w = topk_weights[t, pos].to(device=device)
+                # record weight (as float32 on device)
+                w = topk_weights[t, pos]
                 w_f = w.float() if w.dtype != torch.float32 else w
                 token_row_weight_list[t].append((row, w_f))

-        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
+        # q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
         N = self.w13_weight.size(1)
         gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
         m_grouped_w8a8_gemm_nt_contig_asm(
-            (q_a1_all, q_a1_scale),
+            (hidden_states_packed, hidden_states_scale_packed),
             (self.w13_weight, self.w13_weight_scale),
             gateup_output,
             m_indices,
         )
+        del hidden_states_packed, hidden_states_scale_packed
         q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
         down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
         down_output = m_grouped_w8a8_gemm_nt_contig_asm(
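
For reference, a toy sketch of the layout the packing loop above produces: each token's activations are written into its expert's contiguous padded block, `m_indices` records the expert id per row, and padding rows keep `-1`. All names and sizes below are illustrative only:

```python
import torch

hidden = torch.arange(6, dtype=torch.float32).view(3, 2)  # 3 toy tokens, hidden size 2
token_expert = [(0, 7), (1, 7), (2, 3)]                   # (token, expert) routing pairs

block = {3: 4, 7: 4}       # per-expert row counts, padded to a toy alignment of 4
offset, pad_M = {}, 0
for e in sorted(block):    # experts laid out back to back, one padded block each
    offset[e] = pad_M
    pad_M += block[e]

packed = torch.zeros(pad_M, 2)
m_indices = torch.full((pad_M,), -1, dtype=torch.int32)
slot = {e: 0 for e in block}
for t, e in token_expert:
    row = offset[e] + slot[e]
    packed[row] = hidden[t]   # token's activations land in its expert's block
    m_indices[row] = e        # the grouped GEMM reads the expert id per row
    slot[e] += 1

print(m_indices.tolist())     # [3, -1, -1, -1, 7, 7, -1, -1]
```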
@@ -806,17 +811,10 @@ class DeepEPMoE(EPMoE):
         )

         result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
         for t in range(M):
-            pairs = token_row_weight_list[t]
-            if not pairs:
-                continue
-            acc = None
-            for (row, w) in pairs:
-                vec = down_output[row].float()
-                weighted = vec * w
-                acc = weighted if acc is None else (acc + weighted)
-            result[t] = acc.to(result.dtype)
-        return result
+            for (row, w) in token_row_weight_list[t]:
+                result[t].addcmul_(down_output[row].float(), w)
+        return result.to(down_output.dtype)

     def forward_deepgemm_contiguous(
         self,
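
The rewritten combine step relies on `addcmul_` accumulating `weight * row` in place, which matches what the removed `acc` loop computed term by term. A small self-contained check (the tensors below are stand-ins, not the real activations):

```python
import torch

down_rows = torch.randn(3, 4)    # stand-ins for down_output rows of one token
weights = [0.5, 0.25, 0.25]      # stand-ins for the routing weights

acc = torch.zeros(4)
for row, w in zip(down_rows, weights):
    acc.addcmul_(row, torch.tensor(w))   # acc += row * w, in place

ref = sum(w * row for row, w in zip(down_rows, weights))
assert torch.allclose(acc, ref)
```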
......
@@ -4,7 +4,6 @@ import logging
 from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
-from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.dp_attention import get_is_extend_in_batch
@@ -30,7 +29,7 @@ from sglang.srt.utils import (
     is_npu,
     load_json_config,
 )
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8

 _is_npu = is_npu()

 if TYPE_CHECKING:
@@ -369,6 +368,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             # scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
             # scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
             # )
+            hidden_states = per_token_quant_int8(hidden_states)
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_ids, topk_weights, previous_event
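
With this change the normal dispatcher quantizes activations per token before dispatch, so what DeepEP moves appears to be int8 values plus per-token scales, which the DeepEPMoE path above then passes through as `a1_scale`. The exact contract of lmslim's `per_token_quant_int8` is not shown in the diff; the snippet below is only an assumed, generic per-token int8 quantizer with a `(values, row-scales)` return shape:

```python
import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # one scale per row, chosen so the row's max magnitude maps to 127
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp((x / scale).round(), -128, 127).to(torch.int8)
    return q, scale.squeeze(-1).to(torch.float32)

x = torch.randn(4, 8)
q, s = per_token_quant_int8_ref(x)
assert q.dtype == torch.int8 and q.shape == (4, 8) and s.shape == (4,)
```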
@@ -441,19 +441,12 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             expert_alignment=1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )
-        if self.quant_config.get("quant_method") == "slimquant_w4a8_marlin":
-            self.rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
-            recv_topk_ids = torch.where(
-                recv_topk_ids == -1,
-                self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
-                recv_topk_ids + self.rank_expert_offset)
-        else:
-            get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-                num_recv_tokens_per_expert,
-                num_tokens_per_rank=num_tokens_per_rank,
-                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-                num_tokens_per_expert=num_tokens_per_expert,
-            )
+        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+            num_recv_tokens_per_expert,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+        )
         return (
             recv_x,
             recv_topk_ids,
......
@@ -213,7 +213,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
     ):
         self.moe_runner_config = moe_runner_config
         self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
     def apply(
         self,
@@ -307,7 +307,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
    #     use_nn_moe: Optional[bool] = False,
    #     routed_scaling_factor: Optional[float] = None,
    #     use_fused_gate: Optional[bool] = False,
    #     **_
    # ) -> torch.Tensor:
    #     from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
    #     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -351,8 +351,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
    #         a2_scale=layer.w2_input_scale,
    #         use_nn_moe=use_nn_moe,
    #     )

    def apply_ep(self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
@@ -392,6 +392,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
            global_num_experts=global_num_experts,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
+           a1_scale=a1_scale,
            use_nn_moe=use_nn_moe,
            shared_output=shared_output,
            routed_scaling_factor=routed_scaling_factor,
......