Commit 53cbb488 authored by yiqa's avatar yiqa
Browse files

normal模式下适配w8a8 (adapt W8A8 quantization in normal mode)

parent 3fc0ce15
......@@ -2,7 +2,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from collections import defaultdict
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch
......@@ -626,10 +626,11 @@ class DeepEPMoE(EPMoE):
device = hidden_states.device
M = hidden_states.shape[0]
active_experts = set()
token_expert_pos = [None] * M
K = hidden_states.shape[1]
topk = topk_idx.shape[1]
active_experts = set()
token_expert_pos = [None] * M
for t in range(M):
lst = []
for pos in range(topk):
......@@ -644,13 +645,30 @@ class DeepEPMoE(EPMoE):
if num_active == 0:
return hidden_states.bfloat16()
block = 256
pad_M = block * num_active
K = hidden_states.shape[1]
counts = defaultdict(int)
for t in range(M):
for (e, pos) in token_expert_pos[t]:
counts[e] += 1
per_expert_block = {}
for e in active_experts:
cnt = counts.get(e, 0)
if cnt <= 0:
per_expert_block[e] = 0
else:
needed = ((cnt + 256 - 1) // 256) * 256 # next multiple of 256
per_expert_block[e] = max(256, needed)
expert_slot_offset = {}
offset = 0
for e in active_experts:
expert_slot_offset[e] = offset
offset += per_expert_block[e]
pad_M = offset
hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
expert_slot_offset = {e: i * block for i, e in enumerate(active_experts)}
slot_counters = {e: 0 for e in active_experts}
token_row_weight_list = {t: [] for t in range(M)}
......@@ -658,8 +676,8 @@ class DeepEPMoE(EPMoE):
for (e, pos) in token_expert_pos[t]:
start = expert_slot_offset[e]
slot = slot_counters[e]
if slot >= block:
raise RuntimeError(f"Too many tokens for expert {e} (>block).")
if slot >= per_expert_block[e]:
raise RuntimeError(f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}")
row = start + slot
hidden_states_packed[row] = hidden_states[t]
m_indices[row] = int(e)
......@@ -672,16 +690,13 @@ class DeepEPMoE(EPMoE):
q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
N = self.w13_weight.size(1)
gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
m_grouped_w8a8_gemm_nt_contig_asm(
(q_a1_all, q_a1_scale),
(self.w13_weight, self.w13_weight_scale),
gateup_output,
m_indices,
)
q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
down_output = m_grouped_w8a8_gemm_nt_contig_asm(
(q_a2_all, q_a2_scale),
......@@ -690,7 +705,6 @@ class DeepEPMoE(EPMoE):
m_indices,
)
result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
for t in range(M):
pairs = token_row_weight_list[t]
if not pairs:
......@@ -699,10 +713,7 @@ class DeepEPMoE(EPMoE):
for (row, w) in pairs:
vec = down_output[row].float()
weighted = vec * w
if acc is None:
acc = weighted
else:
acc = acc + weighted
acc = weighted if acc is None else (acc + weighted)
result[t] = acc.to(result.dtype)
return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment