Unverified Commit 2fce82b7 authored by hx's avatar hx Committed by GitHub
Browse files

[MoE][PyTorch] Add mask-based MoE permutation (#1373)



* add mask-based moe permutation

* change moe_chunk_permute to moe_sort_chunks_by_indices

* fix __all__ in pytorch/permutation.py

* fix func/var names and typos; update tols in UT

---------
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
parent c2c3d540
...@@ -52,6 +52,8 @@ pyTorch ...@@ -52,6 +52,8 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.moe_unpermute .. autoapifunction:: transformer_engine.pytorch.moe_unpermute
.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index
.. autoapifunction:: transformer_engine.pytorch.initialize_ub .. autoapifunction:: transformer_engine.pytorch.initialize_ub
.. autoapifunction:: transformer_engine.pytorch.destroy_ub .. autoapifunction:: transformer_engine.pytorch.destroy_ub
This diff is collapsed.
...@@ -74,7 +74,11 @@ from transformer_engine.pytorch.attention import DotProductAttention ...@@ -74,7 +74,11 @@ from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.transformer import TransformerLayer from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute from transformer_engine.pytorch.permutation import (
moe_permute,
moe_unpermute,
moe_sort_chunks_by_index,
)
from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.fp8 import fp8_model_init
from transformer_engine.pytorch.graph import make_graphed_callables from transformer_engine.pytorch.graph import make_graphed_callables
......
...@@ -2,24 +2,26 @@ ...@@ -2,24 +2,26 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""Linear API""" """MoE Permutaion API"""
import warnings import warnings
from typing import Tuple from typing import Tuple
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from .constants import TE_DType import transformer_engine.pytorch.triton.permutation as triton_permutation
from .float8_tensor import Float8Tensor from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.float8_tensor import Float8Tensor
__all__ = [ __all__ = [
"moe_permute", "moe_permute",
"moe_unpermute", "moe_unpermute",
"moe_sort_chunks_by_index",
] ]
class _moe_permute(torch.autograd.Function): class _moe_permute_index_map(torch.autograd.Function):
"""functional Permute""" """functional Permute with index router map"""
workspace = None workspace = None
max_expanded_token_num = 0 max_expanded_token_num = 0
...@@ -28,7 +30,7 @@ class _moe_permute(torch.autograd.Function): ...@@ -28,7 +30,7 @@ class _moe_permute(torch.autograd.Function):
def forward( def forward(
ctx, ctx,
inp: torch.Tensor, inp: torch.Tensor,
indices: torch.Tensor, index: torch.Tensor,
num_out_tokens: int, num_out_tokens: int,
max_token_num: int, max_token_num: int,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -39,9 +41,9 @@ class _moe_permute(torch.autograd.Function): ...@@ -39,9 +41,9 @@ class _moe_permute(torch.autograd.Function):
# Device check # Device check
assert inp.is_cuda, "TransformerEngine needs CUDA." assert inp.is_cuda, "TransformerEngine needs CUDA."
assert indices.is_cuda, "TransformerEngine needs CUDA." assert index.is_cuda, "TransformerEngine needs CUDA."
# Shape check # Shape check
assert inp.size(0) == indices.size(0), "Permute not possible" assert inp.size(0) == index.size(0), "Permute not possible"
# Data type check # Data type check
fp8 = isinstance(inp, Float8Tensor) fp8 = isinstance(inp, Float8Tensor)
...@@ -51,27 +53,27 @@ class _moe_permute(torch.autograd.Function): ...@@ -51,27 +53,27 @@ class _moe_permute(torch.autograd.Function):
inp = inp._data inp = inp._data
else: else:
dtype = TE_DType[inp.dtype] dtype = TE_DType[inp.dtype]
if indices.dtype != torch.int32: if index.dtype != torch.int32:
warnings.warn( warnings.warn(
f"The data type of the input `indices` of Permute is {indices.dtype}! " f"The data type of the input `index` of Permute is {index.dtype}! "
"The recommended type is torch.int32." "The recommended type is torch.int32."
) )
indices = indices.to(torch.int32) index = index.to(torch.int32)
topK = indices.size(1) topK = index.size(1)
input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
if _moe_permute.max_expanded_token_num < input_max_expanded_token_num: if _moe_permute_index_map.max_expanded_token_num < input_max_expanded_token_num:
_moe_permute.max_expanded_token_num = input_max_expanded_token_num _moe_permute_index_map.max_expanded_token_num = input_max_expanded_token_num
_moe_permute.workspace = [] _moe_permute_index_map.workspace = []
permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd( permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd(
inp, inp,
dtype, dtype,
indices, index,
num_out_tokens, num_out_tokens,
_moe_permute.workspace, _moe_permute_index_map.workspace,
_moe_permute.max_expanded_token_num, _moe_permute_index_map.max_expanded_token_num,
) )
if fp8: if fp8:
...@@ -80,8 +82,8 @@ class _moe_permute(torch.autograd.Function): ...@@ -80,8 +82,8 @@ class _moe_permute(torch.autograd.Function):
) )
ctx.row_id_map = row_id_map ctx.row_id_map = row_id_map
ctx.num_tokens = indices.size(0) ctx.num_tokens = index.size(0)
ctx.topK = indices.size(1) ctx.topK = index.size(1)
ctx.fp8 = fp8 ctx.fp8 = fp8
return permuted_act, row_id_map return permuted_act, row_id_map
...@@ -122,8 +124,8 @@ class _moe_permute(torch.autograd.Function): ...@@ -122,8 +124,8 @@ class _moe_permute(torch.autograd.Function):
return act_grad, None, None, None return act_grad, None, None, None
class _moe_unpermute(torch.autograd.Function): class _moe_unpermute_index_map(torch.autograd.Function):
"""functional Unpermute""" """functional Unpermute with index router map"""
@staticmethod @staticmethod
def forward( def forward(
...@@ -225,21 +227,238 @@ class _moe_unpermute(torch.autograd.Function): ...@@ -225,21 +227,238 @@ class _moe_unpermute(torch.autograd.Function):
return act_grad, None, prob_grad return act_grad, None, prob_grad
class _moe_permute_mask_map(torch.autograd.Function):
    """functional Permute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        routing_map: torch.Tensor,
        num_out_tokens: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Nothing to permute for empty input; return an empty row_id_map placeholder.
        if not inp.numel():
            return inp, torch.tensor([], device=inp.device)
        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert routing_map.is_cuda, "TransformerEngine needs CUDA."
        # Shape check: one routing row per input token.
        assert inp.size(0) == routing_map.size(0), "Permute not possible"
        num_tokens, hidden_size = inp.size()
        num_experts = routing_map.size(1)
        assert (
            num_out_tokens is not None
        ), "num_out_tokens must be provided to the fused permute function."
        # Map each (expert, token) pair to its destination row of the permuted output.
        row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)
        # For FP8 inputs, permute the raw uint8 storage and re-wrap afterwards.
        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        output = triton_permutation.permute_with_mask_map(
            inp,
            row_id_map,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
        )
        if fp8:
            output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv)
        ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None
        act_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                permuted_act_grad = permuted_act_grad._data
            else:
                fp8_dtype = None
            # Backward of permute is an unpermute: each token accumulates the
            # gradients of all of its expert copies (probs is None, plain sum).
            act_grad = triton_permutation.unpermute_with_mask_map(
                permuted_act_grad,
                row_id_map,
                None,
                ctx.num_tokens,
                ctx.num_experts,
                ctx.hidden_size,
                fp8_dtype,
            )
            if fp8:
                # The unpermute kernel divides the FP8 accumulation by num_experts
                # to avoid overflow, so compensate in the scale-inverse here.
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv * ctx.num_experts,
                )
        return act_grad, None, None
class _moe_unpermute_mask_map(torch.autograd.Function):
    """functional Unpermute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
        restore_shape: torch.Size,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Empty input: pass through; stash probs so backward can return it unchanged.
        if not inp.numel():
            ctx.probs = probs
            return inp
        if restore_shape is None:
            restore_shape = inp.shape
        num_tokens, hidden_size = restore_shape
        # row_id_map is laid out [num_experts, num_tokens] (see make_row_id_map).
        num_experts = row_id_map.size(0)
        with_probs = probs is not None
        if with_probs:
            assert probs.is_cuda, "TransformerEngine needs CUDA."
            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)
        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            if not with_probs:
                # Without probs the kernel divides the FP8 accumulation by
                # num_experts to avoid overflow; compensate in the scale-inverse.
                fp8_scale_inv = inp._scale_inv * num_experts
            else:
                fp8_scale_inv = inp._scale_inv
            inp = inp._data
        else:
            fp8_dtype = None
        unpermuted_output = triton_permutation.unpermute_with_mask_map(
            inp,
            row_id_map,
            probs,
            num_tokens,
            num_experts,
            hidden_size,
            fp8_dtype=fp8_dtype,
        )
        if fp8:
            unpermuted_output = Float8Tensor(
                data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
            )
        # The forward input (raw data for FP8) and probs are only needed by the
        # probs-weighted backward path.
        if with_probs:
            ctx.save_for_backward(inp, row_id_map, probs)
        else:
            ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.num_permuted_tokens = inp.size(0)
        ctx.hidden_size = hidden_size
        ctx.with_probs = with_probs
        return unpermuted_output

    @staticmethod
    def backward(ctx, unpermuted_act_grad):
        # pylint: disable=missing-function-docstring
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.probs, None
        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            if ctx.with_probs:
                fwd_input, row_id_map, probs = ctx.saved_tensors
            else:
                (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(unpermuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = unpermuted_act_grad._fp8_dtype
                fp8_scale_inv = unpermuted_act_grad._scale_inv
                unpermuted_act_grad = unpermuted_act_grad._data
            else:
                fp8_dtype = None
            if ctx.with_probs:
                # Produces both d(permuted activations) = d(out) * prob and d(probs).
                act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_probs(
                    unpermuted_act_grad,
                    row_id_map,
                    fwd_input,
                    probs,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                    fp8_dtype,
                )
            else:
                # Without probs, the backward of an accumulating unpermute is a
                # plain permute (replicate the gradient to each expert copy).
                act_grad = triton_permutation.permute_with_mask_map(
                    unpermuted_act_grad,
                    row_id_map,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                )
            if fp8:
                act_grad = Float8Tensor(
                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
                )
        # Drop the probs gradient when the caller did not request it.
        if not ctx.needs_input_grad[2]:
            probs_grad = None
        return act_grad, None, probs_grad, None
def moe_permute( def moe_permute(
inp: torch.Tensor, inp: torch.Tensor,
indices: torch.Tensor, routing_map: torch.Tensor,
num_out_tokens: int = -1, num_out_tokens: int = -1,
max_token_num: int = -1, max_token_num: int = -1,
map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Permute the tokens based on the indices. Token with the same index will be grouped together. Permute the tokens based on the routing_map. Token with the same index will be grouped together.
Tokens with the same designated expert will be grouped together.
The routing_map indicates which experts were selected by each token.
Parameters Parameters
---------- ----------
inp: torch.Tensor inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
indices: torch.Tensor routing_map: torch.Tensor
The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'. The token to expert mapping tensor.
If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
The values in it: 1 means the token is routed to this expert and 0 means not.
If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
The values in it are the routed expert indices.
num_out_tokens: int, default = -1 num_out_tokens: int, default = -1
The effective output token count, representing the number of tokens not dropped. The effective output token count, representing the number of tokens not dropped.
By default, set to '-1', meaning no tokens are dropped. By default, set to '-1', meaning no tokens are dropped.
...@@ -247,14 +466,23 @@ def moe_permute( ...@@ -247,14 +466,23 @@ def moe_permute(
The maximum number of tokens, used for workspace allocation. The maximum number of tokens, used for workspace allocation.
By default, set to '-1', meaning the calculation of the size of workspace is By default, set to '-1', meaning the calculation of the size of workspace is
automatically taken over by the operator. automatically taken over by the operator.
map_type: str, default = 'mask'
Type of the routing map tensor.
Options are: 'mask', 'index'.
""" """
return _moe_permute.apply(inp, indices, num_out_tokens, max_token_num) if map_type == "index":
return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
if map_type == "mask":
return _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens)
raise ValueError("map_type should be one of 'mask' or 'index'")
def moe_unpermute( def moe_unpermute(
inp: torch.Tensor, inp: torch.Tensor,
row_id_map: torch.Tensor, row_id_map: torch.Tensor,
probs: torch.Tensor = None, probs: torch.Tensor = None,
restore_shape: torch.Tensor = None,
map_type: str = "mask",
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
...@@ -271,5 +499,109 @@ def moe_unpermute( ...@@ -271,5 +499,109 @@ def moe_unpermute(
The tensor of probabilities corresponding to the permuted tokens. If provided, The tensor of probabilities corresponding to the permuted tokens. If provided,
the unpermuted tokens will be merged with their respective probabilities. the unpermuted tokens will be merged with their respective probabilities.
By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
restore_shape: torch.Tensor
The output shape after the unpermute operation.
map_type: str, default = 'mask'
Type of the routing map tensor. Should be the same as the value passed to moe_permute.
Options are: 'mask', 'index'.
"""
if map_type == "index":
return _moe_unpermute_index_map.apply(inp, row_id_map, probs)
if map_type == "mask":
return _moe_unpermute_mask_map.apply(inp, row_id_map, probs, restore_shape)
raise ValueError("map_type should be one of 'mask' or 'index'")
class _moe_chunk_sort(torch.autograd.Function):
    """functional MoE chunk permute"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        split_sizes: torch.Tensor,
        sorted_idxs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # NOTE(review): the empty-input path returns a tuple while the normal path
        # returns a single tensor (despite the Tuple annotation) — confirm callers.
        if not inp.numel():
            return inp, torch.tensor([], device=inp.device)
        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert split_sizes.is_cuda, "TransformerEngine needs CUDA."
        assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA."
        num_tokens, hidden_size = inp.shape
        num_splits = split_sizes.size(0)
        # One sort index per chunk.
        assert num_splits == sorted_idxs.size(0)
        # For FP8 inputs, sort the raw uint8 storage and re-wrap afterwards.
        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        output, row_id_map = triton_permutation.sort_chunks_by_idx(
            inp,
            split_sizes,
            sorted_idxs,
            num_tokens,
            hidden_size,
            num_splits,
        )
        if fp8:
            output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv)
        # row_id_map records each row's destination so backward can invert the sort.
        ctx.save_for_backward(row_id_map)
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None
        act_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                permuted_act_grad = permuted_act_grad._data
            # Gather gradients back through the saved row map (inverse of forward).
            act_grad = triton_permutation.sort_chunks_by_map(
                permuted_act_grad,
                row_id_map,
                ctx.num_tokens,
                ctx.hidden_size,
            )
            if fp8:
                act_grad = Float8Tensor(
                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
                )
        return act_grad, None, None
def moe_sort_chunks_by_index(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Split and sort the input tensor based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to the split_sizes list and the
    resulting chunks are then reordered according to sorted_index.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    split_sizes: torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index: torch.Tensor
        Chunk indices used to permute the chunks.
    """
    return _moe_chunk_sort.apply(inp, split_sizes, sorted_index)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Kernels written with OpenAI Triton."""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Permutation kernels written with OpenAI Triton."""
from typing import Union
import torch
import triton
import triton.language as tl
from transformer_engine_torch import DType as TE_DType
@triton.jit
def _row_id_map_pass_1_kernel(
    # pointers
    routing_map_ptr,
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # strides
    stride_routing_map_token,
    stride_routing_map_expert,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    """Pass 1 of row-id-map construction: per-block cumulative sums.

    Launched on a (num_experts, cdiv(num_tokens, BLOCK_SIZE)) grid (see
    make_row_id_map). For each expert and each block of tokens, stores the
    within-block running count of routed tokens into row_id_map and the
    block's total routed count into the workspace for pass 2.
    """
    pid_m = tl.program_id(0)  # expert index
    pid_n = tl.program_id(1)  # token-block index
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # 0/1 flags: is each token in this block routed to expert pid_m?
    expert_token_mask = tl.load(
        routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
        mask=(offset < num_tokens),
        other=0,
    ).to(tl.int64)
    # 1-based running index among routed tokens; the trailing multiply zeroes
    # out positions whose token is not routed to this expert.
    row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
    tl.store(
        row_id_map_ptr + pid_m * num_tokens + offset,
        row_id_within_token_block,
        mask=offset < num_tokens,
    )
    # Total routed tokens in this (expert, block) chunk, consumed by pass 2.
    n_tokens_per_block = tl.sum(expert_token_mask)
    tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block)
@triton.jit
def _row_id_map_pass_2_kernel(
    # pointers
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # metas
    WORKSPACE_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Pass 2: convert per-block running counts into global destination row ids.

    Adds the exclusive prefix sum of all preceding (expert, block) chunk counts
    to each within-block index, and writes -1 where a token is not routed.
    """
    pid_m = tl.program_id(0)  # expert index
    pid_n = tl.program_id(1)  # token-block index
    # Linear index of this chunk in row-major (expert, block) order.
    chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_id_within_token_block = tl.load(
        row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0
    )
    workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
    # Counts of all chunks strictly before this one (exclusive prefix via mask).
    n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx)
    # 0 marks "not routed" -> -1; otherwise turn the 1-based local index into a
    # 0-based global destination row.
    row_id = tl.where(
        row_id_within_token_block == 0,
        -1,
        row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
    )
    tl.store(
        row_id_map_ptr + pid_m * num_tokens + offset,
        row_id,
        mask=(offset < num_tokens),
    )
def make_row_id_map(
    routing_map: torch.Tensor,
    num_tokens: int,
    num_experts: int,
):
    """Build a [num_experts, num_tokens] map of destination rows from a routing mask.

    Entries are the permuted-output row for each routed (expert, token) pair;
    unrouted pairs are marked with -1.
    """
    token_block = 256
    launch_grid = (num_experts, triton.cdiv(num_tokens, token_block))
    row_id_map = torch.empty((num_experts, num_tokens), dtype=torch.int64, device="cuda")
    partial_counts = torch.empty(launch_grid, dtype=torch.int64, device="cuda")
    # Pass 1: per-block cumulative sums of the routing mask.
    _row_id_map_pass_1_kernel[launch_grid](
        routing_map,
        row_id_map,
        partial_counts,
        num_tokens,
        routing_map.stride(0),
        routing_map.stride(1),
        token_block,
    )
    # Pass 2: add global prefix offsets and mark unrouted entries with -1.
    _row_id_map_pass_2_kernel[launch_grid](
        row_id_map,
        partial_counts,
        num_tokens,
        triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, token_block)),
        token_block,
    )
    return row_id_map
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _permute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter one source token row to all of its expert destination rows.

    One program per source token (pid). row_id_map is laid out
    [num_experts, num_tokens]; an entry of -1 means the token is not routed
    to that expert.
    """
    pid = tl.program_id(0)
    cur_pos = 0
    # Tile the hidden dimension in chunks of BLOCK_SIZE.
    while cur_pos < hidden_size:
        cur_off = cur_pos + tl.arange(0, BLOCK_SIZE)
        mask = cur_off < hidden_size
        input_off = pid * stride_input_token + cur_off * stride_input_hidden
        inp = tl.load(input_ptr + input_off, mask=mask)
        # Write the same tile once per expert this token was routed to.
        for expert_idx in range(num_experts):
            dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
            if dst_row != -1:
                output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
                tl.store(output_ptr + output_off, inp, mask=mask)
        cur_pos += BLOCK_SIZE
def permute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
):
    """Scatter each input token row to its destination rows given by row_id_map."""
    permuted = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    # One program per source token; the kernel fans each row out to its experts.
    _permute_kernel[(num_tokens,)](
        inp,
        permuted,
        row_id_map,
        num_tokens,
        num_experts,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        permuted.stride(0),
        permuted.stride(1),
    )
    return permuted
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _unpermute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    probs_ptr,
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_probs_expert,
    # metas
    WITH_PROBS: tl.constexpr,
    FP8_DTYPE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Accumulate each token's expert copies back into a single output row.

    One program per destination token. Source rows are found via row_id_map
    (-1 marks "not routed"), optionally weighted by probs, and summed.
    FP8 data arrives as uint8 storage; it is bitcast to the FP8 type and
    accumulated in float16.
    """
    if FP8_DTYPE == "e5m2":
        compute_type = tl.float16
        data_type = tl.float8e5
        pytorch_tensor_dtype = tl.uint8
    elif FP8_DTYPE == "e4m3":
        compute_type = tl.float16
        data_type = tl.float8e4nv
        pytorch_tensor_dtype = tl.uint8
    else:
        compute_type = input_ptr.dtype.element_ty
        assert FP8_DTYPE is None
    pid = tl.program_id(0)
    current_start = 0
    # Tile the hidden dimension in chunks of BLOCK_SIZE.
    while current_start < hidden_size:
        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
        mask = current_offset < hidden_size
        accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
        for expert_idx in range(num_experts):
            src_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
            if src_row != -1:
                input_off = src_row * stride_input_token + current_offset * stride_input_hidden
                inp = tl.load(input_ptr + input_off, mask=mask)
                if FP8_DTYPE is not None:
                    # Reinterpret the uint8 storage as FP8 and widen for summing.
                    inp = inp.to(data_type, bitcast=True).to(compute_type)
                if WITH_PROBS:
                    prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
                    prob = tl.load(probs_ptr + prob_off).to(compute_type)
                    inp *= prob
                accumulator += inp
        if FP8_DTYPE is not None:
            if not WITH_PROBS:
                # Directly adding these value may cause overflow for fp8, we scale it here.
                # The outside fp8_scale_inv is also scaled in the meantime.
                accumulator /= num_experts
            accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
        output_off = pid * stride_output_token + current_offset * stride_output_hidden
        tl.store(output_ptr + output_off, accumulator, mask=mask)
        current_start += BLOCK_SIZE
def unpermute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: Union[torch.Tensor, None],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    fp8_dtype: TE_DType,
):
    """Accumulate permuted rows back into [num_tokens, hidden_size] token rows,
    optionally weighting each expert's copy by probs."""
    # Translate the TE dtype enum into the kernel's string tag.
    if fp8_dtype == TE_DType.kFloat8E5M2:
        fp8_tag = "e5m2"
    elif fp8_dtype == TE_DType.kFloat8E4M3:
        fp8_tag = "e4m3"
    else:
        fp8_tag = None
    has_probs = probs is not None
    out = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    _unpermute_kernel[(num_tokens,)](
        inp,
        out,
        row_id_map,
        probs,
        num_tokens,
        num_experts,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        out.stride(0),
        out.stride(1),
        probs.stride(0) if has_probs else None,
        probs.stride(1) if has_probs else None,
        WITH_PROBS=has_probs,
        FP8_DTYPE=fp8_tag,
    )
    return out
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _unpermute_bwd_with_probs_kernel(
    # pointers
    fwd_output_grad_ptr,
    fwd_input_grad_ptr,
    fwd_input_ptr,
    probs_ptr,
    probs_grad_ptr,
    row_id_map_ptr,
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_fwd_output_grad_token,
    stride_fwd_output_grad_hidden,
    stride_fwd_input_grad_token,
    stride_fwd_input_grad_hidden,
    stride_fwd_input_token,
    stride_fwd_input_hidden,
    stride_probs_token,
    stride_probs_expert,
    stride_probs_grad_token,
    stride_probs_grad_expert,
    # metas
    FP8_DTYPE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward of the probs-weighted unpermute.

    One program per token. For each expert the token was routed to, writes
    grad * prob into that expert's permuted row and accumulates
    sum(fwd_input * grad) over the hidden dim as the prob gradient.
    Unrouted (token, expert) slots receive a zero prob gradient.
    """
    if FP8_DTYPE == "e5m2":
        compute_type = tl.float16
        data_type = tl.float8e5
        pytorch_tensor_dtype = tl.uint8
    elif FP8_DTYPE == "e4m3":
        compute_type = tl.float16
        data_type = tl.float8e4nv
        pytorch_tensor_dtype = tl.uint8
    else:
        compute_type = fwd_output_grad_ptr.dtype.element_ty
        assert FP8_DTYPE is None
    pid = tl.program_id(0)
    for expert_idx in range(num_experts):
        dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
        if dst_row != -1:
            # Accumulate the prob gradient in fp32 for accuracy.
            prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
            current_start = 0
            # Tile the hidden dimension in chunks of BLOCK_SIZE.
            while current_start < hidden_size:
                current_offset = current_start + tl.arange(0, BLOCK_SIZE)
                mask = current_offset < hidden_size
                input_off = (
                    pid * stride_fwd_output_grad_token
                    + current_offset * stride_fwd_output_grad_hidden
                )
                inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
                if FP8_DTYPE is not None:
                    # Reinterpret uint8 storage as FP8 and widen for computation.
                    inp = inp.to(data_type, bitcast=True).to(compute_type)
                probs_off = pid * stride_probs_token + expert_idx * stride_probs_expert
                prob = tl.load(probs_ptr + probs_off).to(compute_type)
                # Activation gradient for this expert's copy: grad * prob.
                output = inp * prob
                if FP8_DTYPE is not None:
                    output = output.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
                output_off = (
                    dst_row * stride_fwd_input_grad_token
                    + current_offset * stride_fwd_input_grad_hidden
                )
                tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)
                fwd_input_off = (
                    dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
                )
                fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
                if FP8_DTYPE is not None:
                    fwd_input = fwd_input.to(data_type, bitcast=True)
                prob_grad_accum += fwd_input.to(tl.float32) * inp.to(tl.float32)
                current_start += BLOCK_SIZE
            probs_grad = tl.sum(prob_grad_accum)
            probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert
            tl.store(probs_grad_ptr + probs_grad_off, probs_grad)
        else:
            # Not routed: the prob had no effect, so its gradient is zero.
            probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert
            tl.store(probs_grad_ptr + probs_grad_off, 0.0)
def unpermute_with_mask_map_bwd_with_probs(
    fwd_output_grad: torch.Tensor,
    row_id_map: torch.Tensor,
    fwd_input: torch.Tensor,
    probs: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
    fp8_dtype: TE_DType,
):
    """Backward of the mask-map unpermute when probs were applied in forward.

    Returns the gradient w.r.t. the permuted activations and the gradient
    w.r.t. the routing probabilities.
    """
    # Translate the TE dtype enum into the kernel's string tag.
    if fp8_dtype == TE_DType.kFloat8E5M2:
        fp8_tag = "e5m2"
    elif fp8_dtype == TE_DType.kFloat8E4M3:
        fp8_tag = "e4m3"
    else:
        fp8_tag = None
    act_grad = torch.empty(
        (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
    )
    probs_grad = torch.empty((num_tokens, num_experts), dtype=probs.dtype, device="cuda")
    _unpermute_bwd_with_probs_kernel[(num_tokens,)](
        fwd_output_grad,
        act_grad,
        fwd_input,
        probs,
        probs_grad,
        row_id_map,
        num_tokens,
        num_experts,
        hidden_size,
        fwd_output_grad.stride(0),
        fwd_output_grad.stride(1),
        act_grad.stride(0),
        act_grad.stride(1),
        fwd_input.stride(0),
        fwd_input.stride(1),
        probs.stride(0),
        probs.stride(1),
        probs_grad.stride(0),
        probs_grad.stride(1),
        fp8_tag,
    )
    return act_grad, probs_grad
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _sort_chunks_by_idxs_kernel(
    # pointers
    input_ptr,
    split_sizes_ptr,
    sorted_indices_ptr,
    output_ptr,
    dst_rows_ptr,
    # sizes
    num_splits,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    # metas
    IDX_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy each row to its position after reordering chunks by sorted_indices.

    One program per input row (pid). split_sizes partitions rows into
    num_splits chunks along dim 0; sorted_indices lists input chunk ids in
    output order. Each row's destination is also recorded in dst_rows_ptr so
    the backward pass can invert the reorder.
    """
    pid = tl.program_id(0)
    load_split_offset = tl.arange(0, IDX_LOAD_WIDTH)
    sorted_indices = tl.load(
        sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits
    )
    # get chunk idx of the current token in the input tensor
    input_chunk_idx = -1
    in_chunk_offset = tl.zeros([], dtype=tl.int64)
    acc_chunk_sizes = tl.zeros([], dtype=tl.int64)
    cursor = 0
    # Linear scan over the running sum of chunk sizes until it passes pid.
    while cursor < num_splits:
        cur_chunk_size = tl.load(split_sizes_ptr + cursor).to(tl.int64)
        acc_chunk_sizes += cur_chunk_size
        if input_chunk_idx == -1 and acc_chunk_sizes > pid:
            input_chunk_idx = cursor
            in_chunk_offset = pid - (acc_chunk_sizes - cur_chunk_size)
        cursor += 1
    # get chunk idx of the current token in the output tensor
    output_chunk_idx = 0
    cursor = 0
    while cursor < num_splits:
        cur_input_idx = tl.load(sorted_indices_ptr + cursor)
        if cur_input_idx == input_chunk_idx:
            output_chunk_idx = cursor
        cursor += 1
    # make row_id_map
    # Output chunk sizes, in output order.
    output_split_sizes = tl.load(
        split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits
    ).to(tl.int64)
    # Destination row = total size of output chunks placed before this row's
    # chunk, plus the row's offset within its chunk.
    output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
    dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
    tl.store(dst_rows_ptr + pid, dst_row)
    current_start = 0
    # Copy the row, tiling the hidden dimension in chunks of BLOCK_SIZE.
    while current_start < hidden_size:
        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
        mask = current_offset < hidden_size
        input_offsets = pid * stride_input_token + current_offset * stride_input_hidden
        output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
        inp = tl.load(input_ptr + input_offsets, mask=mask)
        tl.store(output_ptr + output_offsets, inp, mask=mask)
        current_start += BLOCK_SIZE
def sort_chunks_by_idx(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_indices: torch.Tensor,
    num_tokens: int,
    hidden_size: int,
    num_splits: int,
):
    """Reorder the row chunks defined by split_sizes into the order sorted_indices.

    Returns the reordered tensor and a per-row map of destination rows that the
    backward pass uses to invert the sort.
    """
    dst_rows = torch.empty((num_tokens,), dtype=torch.int64, device="cuda")
    reordered = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    _sort_chunks_by_idxs_kernel[(num_tokens,)](
        inp,
        split_sizes,
        sorted_indices,
        reordered,
        dst_rows,
        num_splits,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        reordered.stride(0),
        reordered.stride(1),
        triton.next_power_of_2(num_splits),
    )
    return reordered, dst_rows
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _sort_chunks_by_map(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    # sizes
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    """Gather rows through a precomputed row map (inverse of the chunk sort).

    One program per output row: output row pid is copied from input row
    row_id_map[pid].
    """
    pid = tl.program_id(0)
    dst_row = tl.load(row_id_map_ptr + pid)
    current_start = 0
    # Copy the row, tiling the hidden dimension in chunks of BLOCK_SIZE.
    while current_start < hidden_size:
        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
        mask = current_offset < hidden_size
        input_offsets = dst_row * stride_input_token + current_offset * stride_input_hidden
        output_offsets = pid * stride_output_token + current_offset * stride_output_hidden
        inp = tl.load(input_ptr + input_offsets, mask=mask)
        tl.store(output_ptr + output_offsets, inp, mask=mask)
        current_start += BLOCK_SIZE
def sort_chunks_by_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    num_tokens: int,
    hidden_size: int,
):
    """Gather rows through a precomputed row map, inverting sort_chunks_by_idx."""
    gathered = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    _sort_chunks_by_map[(num_tokens,)](
        inp,
        gathered,
        row_id_map,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        gathered.stride(0),
        gathered.stride(1),
    )
    return gathered
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment