# The stride variables represent how much to increase the ptr by when moving by 1
# The stride variables represent how much to increase the ptr by when
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# by to get the element one row down (A has M rows).
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_am,
stride_ak,
stride_ak,
stride_be,
stride_be,
...
@@ -50,17 +51,30 @@ def fused_moe_kernel(
...
@@ -50,17 +51,30 @@ def fused_moe_kernel(
compute_type:tl.constexpr,
compute_type:tl.constexpr,
):
):
"""
"""
Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices.
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token.
- A: The input tensor representing tokens with shape (*, K), where '*' can
- B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension.
be any shape representing batches and K is the feature dimension of
- C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated,
each token.
and N is the output feature dimension.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
- sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to.
the number of experts, K is the input feature dimension, and N is
- expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A.
the output feature dimension.
This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids`
- C: The output cache tensor with shape (M, topk, N), where M is the
by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert.
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.