"vscode:/vscode.git/clone" did not exist on "b84c426a8c48c25db6a4a1a14860e845347db1c1"
Commit 25e8b412 authored by yangql's avatar yangql
Browse files

适配gptq/awq的triton moe算子

parent 10349d37
...@@ -414,7 +414,8 @@ def fused_moe_kernel_gptq_awq( ...@@ -414,7 +414,8 @@ def fused_moe_kernel_gptq_awq(
compute_type: tl.constexpr, compute_type: tl.constexpr,
has_zp: tl.constexpr, has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr, use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr): use_int8_w8a16: tl.constexpr,
):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices. token and expert matrices.
...@@ -463,26 +464,50 @@ def fused_moe_kernel_gptq_awq( ...@@ -463,26 +464,50 @@ def fused_moe_kernel_gptq_awq(
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
tl.int64)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N if off_experts == -1:
offs_k = tl.arange(0, BLOCK_SIZE_K) # -----------------------------------------------------------
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + # Write back zeros to the output when the expert is not
offs_k[None, :] * stride_ak) # in the current expert parallel rank.
write_zeros_to_output(
c_ptr,
stride_cm,
stride_cn,
pid_n,
N,
offs_token,
token_mask,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
compute_type,
)
return
off_experts = tl.load(expert_ids_ptr + pid_m) offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
if use_int4_w4a16: if use_int4_w4a16:
b_ptrs = b_ptr + off_experts * stride_be + \ b_ptrs = (
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] // 2) * stride_bk
+ offs_bn[None, :] * stride_bn
)
b_shifter = (offs_k[:, None] % 2) * 4 b_shifter = (offs_k[:, None] % 2) * 4
elif use_int8_w8a16: elif use_int8_w8a16:
b_ptrs = b_ptr + off_experts * stride_be + \ b_ptrs = (
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn b_ptr
+ off_experts * stride_be
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
)
if not has_zp and use_int4_w4a16: if not has_zp and use_int4_w4a16:
b_zp_num = 8 b_zp_num = 8
...@@ -508,33 +533,43 @@ def fused_moe_kernel_gptq_awq( ...@@ -508,33 +533,43 @@ def fused_moe_kernel_gptq_awq(
k_mask = None k_mask = None
k_other = None k_other = None
a = tl.load(a_ptrs, a = tl.load(
mask=token_mask[:, None] & a_ptrs,
(offs_k[None, :] < K - k * BLOCK_SIZE_K), mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0) other=0.0,
)
b = tl.load(b_ptrs) b = tl.load(b_ptrs)
if use_int4_w4a16: if use_int4_w4a16:
b = (b >> b_shifter) & 0xF b = (b >> b_shifter) & 0xF
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ b_scale_ptrs = (
offs_bn[None, :] * stride_bsn + \ b_scale_ptr
((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + off_experts * stride_bse
+ offs_bn[None, :] * stride_bsn
+ ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
)
b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
b_scale = b_scale.to(tl.float32) b_scale = b_scale.to(tl.float32)
if has_zp and use_int4_w4a16: if has_zp and use_int4_w4a16:
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ b_zp_ptrs = (
(offs_bn[None, :] // 2) * stride_bzn + \ b_zp_ptr
offs_k_true * stride_bzk + off_experts * stride_bze
+ (offs_bn[None, :] // 2) * stride_bzn
+ offs_k_true * stride_bzk
)
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
b_zp = ((b_zp >> b_zp_shifter) & 0xF) b_zp = (b_zp >> b_zp_shifter) & 0xF
b_zp = b_zp.to(tl.float32) b_zp = b_zp.to(tl.float32)
elif has_zp and use_int8_w8a16: elif has_zp and use_int8_w8a16:
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ b_zp_ptrs = (
offs_bn[None, :] * stride_bzn + \ b_zp_ptr
offs_k_true * stride_bzk + off_experts * stride_bze
+ offs_bn[None, :] * stride_bzn
+ offs_k_true * stride_bzk
)
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
b_zp = b_zp.to(tl.float32) b_zp = b_zp.to(tl.float32)
...@@ -553,17 +588,14 @@ def fused_moe_kernel_gptq_awq( ...@@ -553,17 +588,14 @@ def fused_moe_kernel_gptq_awq(
b_ptrs += BLOCK_SIZE_K * stride_bk b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT: if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type) accumulator = accumulator.to(compute_type)
# ----------------------------------------------------------- # -----------------------------------------------------------
# Write back the block of the output # Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N) c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
......
...@@ -1787,6 +1787,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1787,6 +1787,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
**_
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment