Commit 1871c26c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.11.0-dev-12.23-yql' into 'v0.11.0-dev'

适配gptq/awq的triton moe算子

See merge request dcutoolkit/deeplearing/vllm!313
parents 10349d37 25e8b412
...@@ -370,51 +370,52 @@ def fused_moe_kernel_awq( ...@@ -370,51 +370,52 @@ def fused_moe_kernel_awq(
@triton.jit @triton.jit
def fused_moe_kernel_gptq_awq( def fused_moe_kernel_gptq_awq(
# Pointers to matrices # Pointers to matrices
a_ptr, a_ptr,
b_ptr, b_ptr,
c_ptr, c_ptr,
b_scale_ptr, b_scale_ptr,
b_zp_ptr, b_zp_ptr,
topk_weights_ptr, topk_weights_ptr,
sorted_token_ids_ptr, sorted_token_ids_ptr,
expert_ids_ptr, expert_ids_ptr,
num_tokens_post_padded_ptr, num_tokens_post_padded_ptr,
# Matrix dimensions # Matrix dimensions
N: tl.constexpr, N: tl.constexpr,
K: tl.constexpr, K: tl.constexpr,
EM, EM,
num_valid_tokens, num_valid_tokens,
# The stride variables represent how much to increase the ptr by when # The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is # moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down # how much to increase `a_ptr` by to get the element one row down
# (A has M rows). # (A has M rows).
stride_am, stride_am,
stride_ak, stride_ak,
stride_be, stride_be,
stride_bk, stride_bk,
stride_bn, stride_bn,
stride_cm, stride_cm,
stride_cn, stride_cn,
stride_bse, stride_bse,
stride_bsk, stride_bsk,
stride_bsn, stride_bsn,
stride_bze, stride_bze,
stride_bzk, stride_bzk,
stride_bzn, stride_bzn,
block_k_diviable: tl.constexpr, block_k_diviable: tl.constexpr,
group_size: tl.constexpr, group_size: tl.constexpr,
# Meta-parameters # Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr, top_k: tl.constexpr,
compute_type: tl.constexpr, compute_type: tl.constexpr,
has_zp: tl.constexpr, has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr, use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr): use_int8_w8a16: tl.constexpr,
):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices. token and expert matrices.
...@@ -463,26 +464,50 @@ def fused_moe_kernel_gptq_awq( ...@@ -463,26 +464,50 @@ def fused_moe_kernel_gptq_awq(
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
tl.int64)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N if off_experts == -1:
offs_k = tl.arange(0, BLOCK_SIZE_K) # -----------------------------------------------------------
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + # Write back zeros to the output when the expert is not
offs_k[None, :] * stride_ak) # in the current expert parallel rank.
write_zeros_to_output(
c_ptr,
stride_cm,
stride_cn,
pid_n,
N,
offs_token,
token_mask,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
compute_type,
)
return
off_experts = tl.load(expert_ids_ptr + pid_m) offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
if use_int4_w4a16: if use_int4_w4a16:
b_ptrs = b_ptr + off_experts * stride_be + \ b_ptrs = (
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] // 2) * stride_bk
+ offs_bn[None, :] * stride_bn
)
b_shifter = (offs_k[:, None] % 2) * 4 b_shifter = (offs_k[:, None] % 2) * 4
elif use_int8_w8a16: elif use_int8_w8a16:
b_ptrs = b_ptr + off_experts * stride_be + \ b_ptrs = (
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn b_ptr
+ off_experts * stride_be
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
)
if not has_zp and use_int4_w4a16: if not has_zp and use_int4_w4a16:
b_zp_num = 8 b_zp_num = 8
...@@ -508,33 +533,43 @@ def fused_moe_kernel_gptq_awq( ...@@ -508,33 +533,43 @@ def fused_moe_kernel_gptq_awq(
k_mask = None k_mask = None
k_other = None k_other = None
a = tl.load(a_ptrs, a = tl.load(
mask=token_mask[:, None] & a_ptrs,
(offs_k[None, :] < K - k * BLOCK_SIZE_K), mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0) other=0.0,
)
b = tl.load(b_ptrs) b = tl.load(b_ptrs)
if use_int4_w4a16: if use_int4_w4a16:
b = (b >> b_shifter) & 0xF b = (b >> b_shifter) & 0xF
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ b_scale_ptrs = (
offs_bn[None, :] * stride_bsn + \ b_scale_ptr
((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + off_experts * stride_bse
+ offs_bn[None, :] * stride_bsn
+ ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
)
b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
b_scale = b_scale.to(tl.float32) b_scale = b_scale.to(tl.float32)
if has_zp and use_int4_w4a16: if has_zp and use_int4_w4a16:
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ b_zp_ptrs = (
(offs_bn[None, :] // 2) * stride_bzn + \ b_zp_ptr
offs_k_true * stride_bzk + off_experts * stride_bze
+ (offs_bn[None, :] // 2) * stride_bzn
+ offs_k_true * stride_bzk
)
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
b_zp = ((b_zp >> b_zp_shifter) & 0xF) b_zp = (b_zp >> b_zp_shifter) & 0xF
b_zp = b_zp.to(tl.float32) b_zp = b_zp.to(tl.float32)
elif has_zp and use_int8_w8a16: elif has_zp and use_int8_w8a16:
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ b_zp_ptrs = (
offs_bn[None, :] * stride_bzn + \ b_zp_ptr
offs_k_true * stride_bzk + off_experts * stride_bze
+ offs_bn[None, :] * stride_bzn
+ offs_k_true * stride_bzk
)
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
b_zp = b_zp.to(tl.float32) b_zp = b_zp.to(tl.float32)
...@@ -553,17 +588,14 @@ def fused_moe_kernel_gptq_awq( ...@@ -553,17 +588,14 @@ def fused_moe_kernel_gptq_awq(
b_ptrs += BLOCK_SIZE_K * stride_bk b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT: if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type) accumulator = accumulator.to(compute_type)
# ----------------------------------------------------------- # -----------------------------------------------------------
# Write back the block of the output # Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N) c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
......
...@@ -1787,6 +1787,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): ...@@ -1787,6 +1787,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
**_
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.fused_experts is None assert self.fused_experts is None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment