Commit 53cbb488 authored by yiqa

Adapt w8a8 in normal mode

parent 3fc0ce15
@@ -2,7 +2,7 @@ from __future__ import annotations
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from collections import defaultdict
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
 from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
 import torch
@@ -626,10 +626,11 @@ class DeepEPMoE(EPMoE):
         device = hidden_states.device
         M = hidden_states.shape[0]
-        active_experts = set()
-        token_expert_pos = [None] * M
+        K = hidden_states.shape[1]
         topk = topk_idx.shape[1]
+        active_experts = set()
+        token_expert_pos = [None] * M
         for t in range(M):
             lst = []
             for pos in range(topk):
@@ -644,13 +645,30 @@ class DeepEPMoE(EPMoE):
         if num_active == 0:
             return hidden_states.bfloat16()
-        block = 256
-        pad_M = block * num_active
-        K = hidden_states.shape[1]
+        counts = defaultdict(int)
+        for t in range(M):
+            for (e, pos) in token_expert_pos[t]:
+                counts[e] += 1
+        per_expert_block = {}
+        for e in active_experts:
+            cnt = counts.get(e, 0)
+            if cnt <= 0:
+                per_expert_block[e] = 0
+            else:
+                needed = ((cnt + 256 - 1) // 256) * 256  # next multiple of 256
+                per_expert_block[e] = max(256, needed)
+        expert_slot_offset = {}
+        offset = 0
+        for e in active_experts:
+            expert_slot_offset[e] = offset
+            offset += per_expert_block[e]
+        pad_M = offset
         hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
         m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
-        expert_slot_offset = {e: i * block for i, e in enumerate(active_experts)}
         slot_counters = {e: 0 for e in active_experts}
         token_row_weight_list = {t: [] for t in range(M)}
@@ -658,8 +676,8 @@ class DeepEPMoE(EPMoE):
             for (e, pos) in token_expert_pos[t]:
                 start = expert_slot_offset[e]
                 slot = slot_counters[e]
-                if slot >= block:
-                    raise RuntimeError(f"Too many tokens for expert {e} (>block).")
+                if slot >= per_expert_block[e]:
+                    raise RuntimeError(f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}")
                 row = start + slot
                 hidden_states_packed[row] = hidden_states[t]
                 m_indices[row] = int(e)
@@ -672,16 +690,13 @@ class DeepEPMoE(EPMoE):
         q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
         N = self.w13_weight.size(1)
         gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
         m_grouped_w8a8_gemm_nt_contig_asm(
             (q_a1_all, q_a1_scale),
             (self.w13_weight, self.w13_weight_scale),
             gateup_output,
             m_indices,
         )
         q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
-        down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
         down_output = m_grouped_w8a8_gemm_nt_contig_asm(
             (q_a2_all, q_a2_scale),
@@ -690,7 +705,6 @@ class DeepEPMoE(EPMoE):
             m_indices,
         )
         result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
         for t in range(M):
             pairs = token_row_weight_list[t]
             if not pairs:
@@ -699,10 +713,7 @@ class DeepEPMoE(EPMoE):
             for (row, w) in pairs:
                 vec = down_output[row].float()
                 weighted = vec * w
-                if acc is None:
-                    acc = weighted
-                else:
-                    acc = acc + weighted
+                acc = weighted if acc is None else (acc + weighted)
             result[t] = acc.to(result.dtype)
         return result
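
For reference, a minimal standalone sketch of the per-expert packing scheme this commit switches to: count how many tokens are routed to each active expert, round each count up to a multiple of 256 (at least one full block), lay the experts out contiguously, and record the owning expert of each packed row in m_indices. The helper name pack_per_expert and its standalone signature are illustrative only; in the commit this logic lives inline in DeepEPMoE's normal-mode w8a8 path.

from collections import defaultdict
import torch

def pack_per_expert(hidden_states, token_expert_pos, block=256):
    # hidden_states: (M, K) activations; token_expert_pos[t]: list of (expert, topk_pos) pairs for token t.
    M, K = hidden_states.shape
    counts = defaultdict(int)
    for t in range(M):
        for (e, _pos) in token_expert_pos[t]:
            counts[e] += 1
    # Each expert's region is its token count rounded up to the next multiple of `block`.
    per_expert_block = {e: max(block, ((c + block - 1) // block) * block) for e, c in counts.items()}
    # Assign each expert a contiguous offset in the packed buffer.
    expert_slot_offset, offset = {}, 0
    for e, size in per_expert_block.items():
        expert_slot_offset[e] = offset
        offset += size
    packed = torch.zeros((offset, K), dtype=hidden_states.dtype)
    m_indices = torch.full((offset,), -1, dtype=torch.int32)  # -1 marks padding rows
    slot = defaultdict(int)
    for t in range(M):
        for (e, _pos) in token_expert_pos[t]:
            row = expert_slot_offset[e] + slot[e]
            packed[row] = hidden_states[t]
            m_indices[row] = e
            slot[e] += 1
    return packed, m_indices, expert_slot_offset

Compared with the previous fixed layout of 256 rows per active expert, this sizing means an expert that receives more than 256 tokens no longer trips the slot >= block guard.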