Commit 0a71d6b1 authored by yiqa's avatar yiqa Committed by lizhigong
Browse files

V0.5.4 dev yiqa

parent 078de197
......@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from collections import defaultdict
from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch
......@@ -119,8 +120,8 @@ def fuse_silu_mul_quant_ep_wrapper(
topk,
expect_m
)
def fuse_silu_mul_quant_ep_fake(
input: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None,
......@@ -695,6 +696,11 @@ class DeepEPMoE(EPMoE):
if all_tokens <= 0:
return hidden_states.bfloat16()
rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
topk_idx = torch.where(
topk_idx == -1,
self.num_experts - 1 if rank_expert_offset == 0 else 0,
topk_idx + rank_expert_offset)
expert_output = self.quant_method.apply_ep(
x=hidden_states,
w1=self.w13_weight,
......@@ -708,6 +714,7 @@ class DeepEPMoE(EPMoE):
use_nn_moe=False,
w1_scale=self.w13_weight_scale,
w2_scale=self.w2_weight_scale,
a1_scale=hidden_states_scale,
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor,
)
return expert_output
......@@ -740,10 +747,9 @@ class DeepEPMoE(EPMoE):
active_experts.add(e)
token_expert_pos[t] = lst
active_experts = sorted(list(active_experts))
num_active = len(active_experts)
if num_active == 0:
if not active_experts:
return hidden_states.bfloat16()
active_experts = sorted(list(active_experts))
counts = defaultdict(int)
for t in range(M):
......@@ -752,12 +758,9 @@ class DeepEPMoE(EPMoE):
per_expert_block = {}
for e in active_experts:
cnt = counts.get(e, 0)
if cnt <= 0:
per_expert_block[e] = 0
else:
needed = ((cnt + 256 - 1) // 256) * 256 # next multiple of 256
per_expert_block[e] = max(256, needed)
cnt = counts[e]
needed = ((cnt + 255) // 256) * 256 # same as ceil(cnt/256)*256
per_expert_block[e] = max(256, needed)
expert_slot_offset = {}
offset = 0
......@@ -766,7 +769,8 @@ class DeepEPMoE(EPMoE):
offset += per_expert_block[e]
pad_M = offset
hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
hidden_states_packed = torch.empty((pad_M, K), device=device, dtype=hidden_states.dtype)
hidden_states_scale_packed = torch.empty((pad_M,), device=device, dtype=hidden_states_scale.dtype)
m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
slot_counters = {e: 0 for e in active_experts}
......@@ -776,26 +780,27 @@ class DeepEPMoE(EPMoE):
for (e, pos) in token_expert_pos[t]:
start = expert_slot_offset[e]
slot = slot_counters[e]
if slot >= per_expert_block[e]:
raise RuntimeError(f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}")
row = start + slot
hidden_states_packed[row] = hidden_states[t]
m_indices[row] = int(e)
hidden_states_scale_packed[row] = hidden_states_scale[t]
m_indices[row] = e
slot_counters[e] += 1
w = topk_weights[t, pos].to(device=device)
# record weight (as float32 on device)
w = topk_weights[t, pos]
w_f = w.float() if w.dtype != torch.float32 else w
token_row_weight_list[t].append((row, w_f))
q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
# q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
N = self.w13_weight.size(1)
gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
m_grouped_w8a8_gemm_nt_contig_asm(
(q_a1_all, q_a1_scale),
(hidden_states_packed, hidden_states_scale_packed),
(self.w13_weight, self.w13_weight_scale),
gateup_output,
m_indices,
)
del hidden_states_packed, hidden_states_scale_packed
q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
down_output = m_grouped_w8a8_gemm_nt_contig_asm(
......@@ -806,17 +811,10 @@ class DeepEPMoE(EPMoE):
)
result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
for t in range(M):
pairs = token_row_weight_list[t]
if not pairs:
continue
acc = None
for (row, w) in pairs:
vec = down_output[row].float()
weighted = vec * w
acc = weighted if acc is None else (acc + weighted)
result[t] = acc.to(result.dtype)
return result
for (row, w) in token_row_weight_list[t]:
result[t].addcmul_(down_output[row].float(), w)
return result.to(down_output.dtype)
def forward_deepgemm_contiguous(
self,
......
......@@ -4,7 +4,6 @@ import logging
from contextlib import nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
from sglang.srt.distributed import get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.dp_attention import get_is_extend_in_batch
......@@ -30,7 +29,7 @@ from sglang.srt.utils import (
is_npu,
load_json_config,
)
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
_is_npu = is_npu()
if TYPE_CHECKING:
......@@ -369,6 +368,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
# scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
# scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
# )
hidden_states = per_token_quant_int8(hidden_states)
previous_event = Buffer.capture() if self.async_finish else None
return hidden_states, topk_ids, topk_weights, previous_event
......@@ -441,19 +441,12 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
expert_alignment=1,
config=DeepEPConfig.get_instance().normal_dispatch_config,
)
if self.quant_config.get("quant_method") == "slimquant_w4a8_marlin":
self.rank_expert_offset= get_moe_expert_parallel_rank() * ( self.num_experts // get_moe_expert_parallel_world_size())
recv_topk_ids = torch.where(
recv_topk_ids == -1,
self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
recv_topk_ids + self.rank_expert_offset)
else:
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
num_recv_tokens_per_expert,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
num_tokens_per_expert=num_tokens_per_expert,
)
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
num_recv_tokens_per_expert,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
num_tokens_per_expert=num_tokens_per_expert,
)
return (
recv_x,
recv_topk_ids,
......
......@@ -213,7 +213,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
):
self.moe_runner_config = moe_runner_config
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
@torch._dynamo.disable() # TODO: 性能优化需lmslim/lightop配合
def apply(
self,
......@@ -307,7 +307,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
# use_nn_moe: Optional[bool] = False,
# routed_scaling_factor: Optional[float] = None,
# use_fused_gate: Optional[bool] = False,
# **_
# **_
# ) -> torch.Tensor:
# from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
# from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
......@@ -351,8 +351,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
# a2_scale=layer.w2_input_scale,
# use_nn_moe=use_nn_moe,
# )
def apply_ep(self,
def apply_ep(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
......@@ -392,6 +392,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
global_num_experts=global_num_experts,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment