Commit 7211d743 authored by yiqa, committed by lizhigong

w8a8: adapt to the contiguous (contig) grouped-GEMM operators

parent 5533c538
@@ -2,7 +2,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from collections import defaultdict
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch
@@ -41,7 +41,7 @@ if TYPE_CHECKING:
        DeepEPNormalOutput,
        DispatchOutput,
    )
-from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked
+from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked, m_grouped_w8a8_gemm_nt_contig_asm, fuse_silu_mul_quant
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
_is_hip = is_hip()
@@ -606,6 +608,8 @@ class DeepEPMoE(EPMoE):
            return self.forward_deepgemm_contiguous(dispatch_output)
        elif self.use_w4a8_marlin:
            return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
        elif self.use_w8a8_marlin:
            return self.forward_groupgemm_w8a8_marlin_contiguous(dispatch_output)
        else:
            raise ValueError(
                "Dispatch output is not supported"
@@ -710,6 +712,111 @@
        )
        return expert_output

    def forward_groupgemm_w8a8_marlin_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        (
            hidden_states,
            hidden_states_scale,
            topk_idx,
            topk_weights,
            num_recv_tokens_per_expert,
        ) = dispatch_output
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"

        all_tokens = sum(num_recv_tokens_per_expert)
        if all_tokens <= 0:
            return hidden_states.bfloat16()

        device = hidden_states.device
        M = hidden_states.shape[0]
        K = hidden_states.shape[1]
        topk = topk_idx.shape[1]

        # Collect, per token, the (expert, top-k position) pairs it was routed
        # to, and the set of experts that receive at least one token.
        active_experts = set()
        token_expert_pos = [None] * M
        for t in range(M):
            lst = []
            for pos in range(topk):
                e = int(topk_idx[t, pos].item())
                if e >= 0:
                    lst.append((e, pos))
                    active_experts.add(e)
            token_expert_pos[t] = lst
        active_experts = sorted(active_experts)
        num_active = len(active_experts)
        if num_active == 0:
            return hidden_states.bfloat16()

        # Count rows per expert and pad each expert's segment up to a multiple
        # of 256 so the contiguous grouped GEMM sees aligned blocks.
        counts = defaultdict(int)
        for t in range(M):
            for e, pos in token_expert_pos[t]:
                counts[e] += 1
        per_expert_block = {}
        for e in active_experts:
            cnt = counts.get(e, 0)
            if cnt <= 0:
                per_expert_block[e] = 0
            else:
                needed = ((cnt + 256 - 1) // 256) * 256  # next multiple of 256
                per_expert_block[e] = max(256, needed)

        # Assign each expert a contiguous slice of rows in the packed buffer.
        expert_slot_offset = {}
        offset = 0
        for e in active_experts:
            expert_slot_offset[e] = offset
            offset += per_expert_block[e]
        pad_M = offset

        # Scatter token rows into their experts' slices; m_indices records the
        # owning expert of each packed row (-1 for padding rows).
        hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
        m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
        slot_counters = {e: 0 for e in active_experts}
        token_row_weight_list = {t: [] for t in range(M)}
        for t in range(M):
            for e, pos in token_expert_pos[t]:
                start = expert_slot_offset[e]
                slot = slot_counters[e]
                if slot >= per_expert_block[e]:
                    raise RuntimeError(
                        f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}"
                    )
                row = start + slot
                hidden_states_packed[row] = hidden_states[t]
                m_indices[row] = int(e)
                slot_counters[e] += 1
                w = topk_weights[t, pos].to(device=device)
                w_f = w.float() if w.dtype != torch.float32 else w
                token_row_weight_list[t].append((row, w_f))

        # Gate/up projection: per-token int8 quant, then grouped w8a8 GEMM.
        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
        N = self.w13_weight.size(1)
        gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
        m_grouped_w8a8_gemm_nt_contig_asm(
            (q_a1_all, q_a1_scale),
            (self.w13_weight, self.w13_weight_scale),
            gateup_output,
            m_indices,
        )

        # Fused SiLU-and-mul with requantization, then the down projection
        # through the same contiguous grouped GEMM.
        q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
        down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
        down_output = m_grouped_w8a8_gemm_nt_contig_asm(
            (q_a2_all, q_a2_scale),
            (self.w2_weight, self.w2_weight_scale),
            down_output,
            m_indices,
        )

        # Gather expert outputs back to token order, combining each token's
        # top-k rows with its routing weights.
        result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
        for t in range(M):
            pairs = token_row_weight_list[t]
            if not pairs:
                continue
            acc = None
            for row, w in pairs:
                vec = down_output[row].float()
                weighted = vec * w
                acc = weighted if acc is None else (acc + weighted)
            result[t] = acc.to(result.dtype)
        return result
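
The packing layout built by forward_groupgemm_w8a8_marlin_contiguous is easiest to see on a toy input. The sketch below is a minimal standalone re-implementation of just the packing step, not the kernel-facing code; pack_tokens_contiguous and the block size of 4 are made up for illustration (the path above pads to 256):

import torch

def pack_tokens_contiguous(hidden_states, topk_idx, block=256):
    # Replicate every token row once per routed expert, group rows by
    # expert, and pad each expert's segment to a multiple of `block`.
    M, K = hidden_states.shape
    pairs = [(int(e), t) for t in range(M) for e in topk_idx[t].tolist() if e >= 0]
    experts = sorted({e for e, _ in pairs})
    counts = {e: sum(1 for ee, _ in pairs if ee == e) for e in experts}
    offsets, off = {}, 0
    for e in experts:
        offsets[e] = off
        off += ((counts[e] + block - 1) // block) * block
    packed = torch.zeros((off, K), dtype=hidden_states.dtype)
    m_indices = torch.full((off,), -1, dtype=torch.int32)
    cursor = dict(offsets)
    for e, t in sorted(pairs):
        packed[cursor[e]] = hidden_states[t]
        m_indices[cursor[e]] = e
        cursor[e] += 1
    return packed, m_indices

# 3 tokens, top-2 routing over 4 experts; -1 marks a dropped slot.
x = torch.randn(3, 8)
topk_idx = torch.tensor([[0, 2], [2, -1], [0, 3]])
packed, m_idx = pack_tokens_contiguous(x, topk_idx, block=4)
print(packed.shape)    # torch.Size([12, 8]) -> three padded segments of 4
print(m_idx.tolist())  # [0, 0, -1, -1, 2, 2, -1, -1, 3, -1, -1, -1]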

    def forward_deepgemm_contiguous(
        self,
@@ -900,7 +1007,7 @@ class DeepEPMoE(EPMoE):
        # base shapes
        num_groups, m, k = hidden_states.size()
        expected_m = min(m, expected_m)

        # ---- first quant: ensure float input for quantizer ----
        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
@@ -949,7 +1056,7 @@
        assert self.moe_runner_config.activation == "silu"

        # base shapes
        num_groups, m, k = hidden_states.size()
        expected_m = min(m, expected_m)

        # ---- first quant: ensure float input for quantizer ----
        q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
......
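
Both masked paths quantize activations per token before the int8 grouped GEMM. The reference below is an assumption about the semantics of per_token_quant_int8 / per_token_quant_int8_triton_opt (symmetric per-row scales, values rounded into [-127, 127]), not their actual implementation:

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per row so that q.float() * scale ~= x.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale

x = torch.randn(4, 16)
q, s = per_token_quant_int8_ref(x)
print((q.float() * s - x).abs().max())  # round-off on the order of scale/2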
@@ -441,18 +441,19 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
            expert_alignment=1,
            config=DeepEPConfig.get_instance().normal_dispatch_config,
        )
-        # get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-        #     num_recv_tokens_per_expert,
-        #     num_tokens_per_rank=num_tokens_per_rank,
-        #     num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-        #     num_tokens_per_expert=num_tokens_per_expert,
-        # )
-        self.rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
-        recv_topk_ids = torch.where(
-            recv_topk_ids == -1,
-            self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
-            recv_topk_ids + self.rank_expert_offset)
+        if self.quant_config.get("quant_method") == "slimquant_w4a8_marlin":
+            self.rank_expert_offset = get_moe_expert_parallel_rank() * (self.num_experts // get_moe_expert_parallel_world_size())
+            recv_topk_ids = torch.where(
+                recv_topk_ids == -1,
+                self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
+                recv_topk_ids + self.rank_expert_offset,
+            )
+        else:
+            get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+                num_recv_tokens_per_expert,
+                num_tokens_per_rank=num_tokens_per_rank,
+                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+                num_tokens_per_expert=num_tokens_per_expert,
+            )
        return (
            recv_x,
            recv_topk_ids,
......
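
The new slimquant_w4a8_marlin branch rewrites the received top-k ids from rank-local to global numbering and parks dropped (-1) slots on an expert id owned by another rank. A standalone sketch of the same torch.where, with made-up sizes (8 experts, 2 ranks, rank 1):

import torch

num_experts, world_size, rank = 8, 2, 1
rank_expert_offset = rank * (num_experts // world_size)  # 4

recv_topk_ids = torch.tensor([[0, 3], [-1, 2]])
remapped = torch.where(
    recv_topk_ids == -1,
    # sentinel for dropped slots: the last expert if this rank owns
    # expert 0, otherwise expert 0 (an expert this rank does not own)
    num_experts - 1 if rank_expert_offset == 0 else 0,
    recv_topk_ids + rank_expert_offset,
)
print(remapped)  # tensor([[4, 7], [0, 6]])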