Commit 25e8b412 authored by yangql's avatar yangql
Browse files

适配gptq/awq的triton moe算子

parent 10349d37
......@@ -414,7 +414,8 @@ def fused_moe_kernel_gptq_awq(
compute_type: tl.constexpr,
has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr):
use_int8_w8a16: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
......@@ -463,26 +464,50 @@ def fused_moe_kernel_gptq_awq(
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
tl.int64)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(
c_ptr,
stride_cm,
stride_cn,
pid_n,
N,
offs_token,
token_mask,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
compute_type,
)
return
off_experts = tl.load(expert_ids_ptr + pid_m)
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
if use_int4_w4a16:
b_ptrs = b_ptr + off_experts * stride_be + \
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] // 2) * stride_bk
+ offs_bn[None, :] * stride_bn
)
b_shifter = (offs_k[:, None] % 2) * 4
elif use_int8_w8a16:
b_ptrs = b_ptr + off_experts * stride_be + \
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
)
if not has_zp and use_int4_w4a16:
b_zp_num = 8
......@@ -508,33 +533,43 @@ def fused_moe_kernel_gptq_awq(
k_mask = None
k_other = None
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs)
if use_int4_w4a16:
b = (b >> b_shifter) & 0xF
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
offs_bn[None, :] * stride_bsn + \
((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
b_scale_ptrs = (
b_scale_ptr
+ off_experts * stride_bse
+ offs_bn[None, :] * stride_bsn
+ ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
)
b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
b_scale = b_scale.to(tl.float32)
if has_zp and use_int4_w4a16:
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
(offs_bn[None, :] // 2) * stride_bzn + \
offs_k_true * stride_bzk
b_zp_ptrs = (
b_zp_ptr
+ off_experts * stride_bze
+ (offs_bn[None, :] // 2) * stride_bzn
+ offs_k_true * stride_bzk
)
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
b_zp = ((b_zp >> b_zp_shifter) & 0xF)
b_zp = (b_zp >> b_zp_shifter) & 0xF
b_zp = b_zp.to(tl.float32)
elif has_zp and use_int8_w8a16:
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
offs_bn[None, :] * stride_bzn + \
offs_k_true * stride_bzk
b_zp_ptrs = (
b_zp_ptr
+ off_experts * stride_bze
+ offs_bn[None, :] * stride_bzn
+ offs_k_true * stride_bzk
)
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
b_zp = b_zp.to(tl.float32)
......@@ -553,17 +588,14 @@ def fused_moe_kernel_gptq_awq(
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
......
......@@ -1787,6 +1787,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
**_
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment