Commit f9a784a7 authored by yangql
Browse files

更新curr_topk_ids

parent cd87548a
......@@ -384,26 +384,17 @@ def fused_moe_kernel_gptq_awq(
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
offs_token, token_mask, BLOCK_SIZE_M,
BLOCK_SIZE_N, compute_type)
return
offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if use_int4_w4a16:
b_ptrs = b_ptr + off_experts * stride_be + \
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \
stride_bn
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
b_shifter = (offs_k[:, None] % 2) * 4
elif use_int8_w8a16:
b_ptrs = b_ptr + off_experts * stride_be + \
......@@ -443,8 +434,7 @@ def fused_moe_kernel_gptq_awq(
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
offs_bn[None, :] * stride_bsn + \
((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \
stride_bsk
((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
b_scale = b_scale.to(tl.float32)
......@@ -716,6 +706,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B_scale: Optional[torch.Tensor],
B_zp: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
......@@ -1709,6 +1700,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w1_scale,
w1_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
......@@ -1769,6 +1761,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_scale,
w2_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment