"src/vscode:/vscode.git/clone" did not exist on "cdda94f412413eb1fe86b17dbb2b1e8d108aa59e"
Commit 3fc0ce15 authored by yiqa

Adapt w8a8 in normal mode

parent a50eb0e6
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
         DeepEPNormalOutput,
         DispatchOutput,
     )
-from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked
+from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant_ep, m_grouped_w8a8_gemm_nt_masked, m_grouped_w8a8_gemm_nt_contig_asm, fuse_silu_mul_quant
 from lmslim.layers.gemm.int8_utils import per_token_quant_int8

 _is_hip = is_hip()
@@ -506,6 +506,8 @@ class DeepEPMoE(EPMoE):
             return self.forward_deepgemm_contiguous(dispatch_output)
         elif self.use_w4a8_marlin:
             return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
+        elif self.use_w8a8_marlin:
+            return self.forward_groupgemm_w8a8_marlin_contiguous(dispatch_output)
         else:
             raise ValueError(
                 f"Dispatch output is not supported"
@@ -610,6 +612,100 @@ class DeepEPMoE(EPMoE):
         )
         return expert_output

+    def forward_groupgemm_w8a8_marlin_contiguous(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = dispatch_output
+        assert self.quant_method is not None
+        assert self.moe_runner_config.activation == "silu"
+
+        all_tokens = sum(num_recv_tokens_per_expert)
+        if all_tokens <= 0:
+            return hidden_states.bfloat16()
+
+        device = hidden_states.device
+        M = hidden_states.shape[0]
+        active_experts = set()
+        token_expert_pos = [None] * M
+        topk = topk_idx.shape[1]
+        for t in range(M):
+            lst = []
+            for pos in range(topk):
+                e = int(topk_idx[t, pos].item())
+                if e >= 0:
+                    lst.append((e, pos))
+                    active_experts.add(e)
+            token_expert_pos[t] = lst
+
+        active_experts = sorted(list(active_experts))
+        num_active = len(active_experts)
+        if num_active == 0:
+            return hidden_states.bfloat16()
+
+        block = 256
+        pad_M = block * num_active
+        K = hidden_states.shape[1]
+        hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
+        m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
+        expert_slot_offset = {e: i * block for i, e in enumerate(active_experts)}
+        slot_counters = {e: 0 for e in active_experts}
+        token_row_weight_list = {t: [] for t in range(M)}
+        for t in range(M):
+            for (e, pos) in token_expert_pos[t]:
+                start = expert_slot_offset[e]
+                slot = slot_counters[e]
+                if slot >= block:
+                    raise RuntimeError(f"Too many tokens for expert {e} (>block).")
+                row = start + slot
+                hidden_states_packed[row] = hidden_states[t]
+                m_indices[row] = int(e)
+                slot_counters[e] += 1
+                w = topk_weights[t, pos].to(device=device)
+                w_f = w.float() if w.dtype != torch.float32 else w
+                token_row_weight_list[t].append((row, w_f))
+
+        q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
+        N = self.w13_weight.size(1)
+        gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
+        m_grouped_w8a8_gemm_nt_contig_asm(
+            (q_a1_all, q_a1_scale),
+            (self.w13_weight, self.w13_weight_scale),
+            gateup_output,
+            m_indices,
+        )
+
+        q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
+        down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
+        down_output = m_grouped_w8a8_gemm_nt_contig_asm(
+            (q_a2_all, q_a2_scale),
+            (self.w2_weight, self.w2_weight_scale),
+            down_output,
+            m_indices,
+        )
+
+        result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
+        for t in range(M):
+            pairs = token_row_weight_list[t]
+            if not pairs:
+                continue
+            acc = None
+            for (row, w) in pairs:
+                vec = down_output[row].float()
+                weighted = vec * w
+                if acc is None:
+                    acc = weighted
+                else:
+                    acc = acc + weighted
+            result[t] = acc.to(result.dtype)
+        return result
+
     def forward_deepgemm_contiguous(
         self,
...
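
Editor's note (illustrative sketch, not part of the commit): the new forward_groupgemm_w8a8_marlin_contiguous path packs each token's hidden state once per selected expert into fixed-size contiguous blocks (one 256-row block per active expert, tracked by m_indices), runs the grouped int8 GEMMs over that layout, and then recombines the per-expert rows with their top-k weights. The standalone Python sketch below mimics only the packing and weighted recombination, with a plain callable standing in for the lightop grouped-GEMM and quantization kernels; pack_and_combine, expert_fn, and the toy sizes are made-up names for illustration.

    import torch

    def pack_and_combine(hidden_states, topk_idx, topk_weights, expert_fn, block=256):
        M, K = hidden_states.shape
        # Experts that receive at least one token; -1 entries are unused slots.
        active = sorted({int(e) for e in topk_idx.flatten().tolist() if e >= 0})
        offset = {e: i * block for i, e in enumerate(active)}
        counters = {e: 0 for e in active}
        packed = torch.zeros((block * len(active), K))
        m_indices = torch.full((block * len(active),), -1, dtype=torch.int32)
        rows_per_token = [[] for _ in range(M)]
        # Copy each token into the block of every expert it was routed to.
        for t in range(M):
            for pos in range(topk_idx.shape[1]):
                e = int(topk_idx[t, pos])
                if e < 0:
                    continue
                row = offset[e] + counters[e]
                counters[e] += 1
                packed[row] = hidden_states[t]
                m_indices[row] = e
                rows_per_token[t].append((row, float(topk_weights[t, pos])))
        # Stand-in for the grouped GEMM: apply each expert's function row by row.
        out = torch.zeros_like(packed)
        for row in range(packed.shape[0]):
            e = int(m_indices[row])
            if e >= 0:
                out[row] = expert_fn(e, packed[row])
        # Weighted recombination back to one output row per token.
        result = torch.zeros((M, K))
        for t, pairs in enumerate(rows_per_token):
            for row, w in pairs:
                result[t] += w * out[row]
        return result

    # Toy usage: 4 tokens, top-2 routing, experts that just scale their input.
    x = torch.randn(4, 8)
    idx = torch.tensor([[0, 1], [1, -1], [0, 1], [1, 0]])
    w = torch.full((4, 2), 0.5)
    y = pack_and_combine(x, idx, w, lambda e, v: (e + 1.0) * v, block=4)
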
@@ -441,18 +441,19 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             expert_alignment=1,
             config=DeepEPConfig.get_instance().normal_dispatch_config,
         )
-        # get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
-        #     num_recv_tokens_per_expert,
-        #     num_tokens_per_rank=num_tokens_per_rank,
-        #     num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
-        #     num_tokens_per_expert=num_tokens_per_expert,
-        # )
-        self.rank_expert_offset= get_moe_expert_parallel_rank() * ( self.num_experts // get_moe_expert_parallel_world_size())
-        recv_topk_ids = torch.where(
-            recv_topk_ids == -1,
-            self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
-            recv_topk_ids + self.rank_expert_offset)
+        if self.quant_config.get("quant_method") == "slimquant_w4a8_marlin":
+            self.rank_expert_offset= get_moe_expert_parallel_rank() * ( self.num_experts // get_moe_expert_parallel_world_size())
+            recv_topk_ids = torch.where(
+                recv_topk_ids == -1,
+                self.num_experts - 1 if self.rank_expert_offset == 0 else 0,
+                recv_topk_ids + self.rank_expert_offset)
+        else:
+            get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+                num_recv_tokens_per_expert,
+                num_tokens_per_rank=num_tokens_per_rank,
+                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+                num_tokens_per_expert=num_tokens_per_expert,
+            )
         return (
             recv_x,
             recv_topk_ids,
...
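
Editor's note (illustrative sketch, not part of the commit): in the slimquant_w4a8_marlin branch above, the locally indexed recv_topk_ids are shifted back to global expert ids by the rank's expert offset, and -1 entries are redirected to an expert id this rank does not own, presumably so those padded slots contribute nothing to the local grouped GEMM. A minimal standalone example of that torch.where remapping, with made-up sizes:

    import torch

    # Hypothetical setup: 8 experts split across 2 EP ranks, this is rank 1.
    num_experts, ep_rank, ep_world_size = 8, 1, 2
    rank_expert_offset = ep_rank * (num_experts // ep_world_size)  # 4

    # Locally indexed top-k ids from dispatch; -1 marks an unused slot.
    recv_topk_ids = torch.tensor([[0, 2], [-1, 1], [3, -1]])

    remapped = torch.where(
        recv_topk_ids == -1,
        # Placeholder expert owned by some other rank: the last expert when
        # this rank holds the first block of experts, expert 0 otherwise.
        torch.tensor(num_experts - 1 if rank_expert_offset == 0 else 0),
        recv_topk_ids + rank_expert_offset,
    )
    # remapped == [[4, 6], [0, 5], [7, 0]]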