Unverified Commit 2fce82b7 authored by hx's avatar hx Committed by GitHub
Browse files

[MoE][PyTorch] Add mask-based MoE permutation (#1373)



* add mask-based moe permutation

* change moe_chunk_permute to moe_sort_chunks_by_indices

* fix __all__ in pytorch/permutation.py

* fix func/var names and typos; update tols in UT

---------
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent c2c3d540
......@@ -52,6 +52,8 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index
.. autoapifunction:: transformer_engine.pytorch.initialize_ub
.. autoapifunction:: transformer_engine.pytorch.destroy_ub
......@@ -2,11 +2,17 @@
#
# See LICENSE for license information.
import random
import torch
import pytest
from typing import Dict, List
from transformer_engine.pytorch import moe_permute as te_permute, moe_unpermute as te_unpermute
from transformer_engine.pytorch import (
moe_permute as te_permute,
moe_unpermute as te_unpermute,
moe_sort_chunks_by_index as te_sort_chunks_by_index,
)
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor
......@@ -18,7 +24,7 @@ torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def pytorch_permute(tokens, indices, num_out_tokens: int = None):
def pytorch_permute_index_map(tokens, indices, num_out_tokens: int = None):
"""
Permute the tokens based on the indices. Token with the same index will be grouped together.
The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately.
......@@ -50,7 +56,7 @@ def pytorch_permute(tokens, indices, num_out_tokens: int = None):
return permuted_tokens, sorted_indices
def pytorch_unpermute(
def pytorch_unpermute_index_map(
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
probs: torch.Tensor = None,
......@@ -95,6 +101,86 @@ def pytorch_unpermute(
return unpermuted_tokens
def pytorch_permute_mask_map(tokens, routing_map):
    """Group tokens by their designated experts using a mask routing map.

    The routing map is a [num_tokens, num_experts] mask marking, for each
    token, which experts it was dispatched to. The output holds the selected
    tokens in expert-major order (all of expert 0's tokens, then expert 1's,
    and so on), duplicating any token routed to multiple experts.

    Args:
        tokens (torch.Tensor): The input token tensor, [num_tokens, hidden].
        routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].

    Returns:
        Tuple of (permuted tokens, token indices used for the gather).
    """
    # Flip to expert-major layout: [num_experts, num_tokens].
    expert_major_mask = routing_map.bool().T.contiguous()
    # nonzero() walks the mask in row-major order, so column 1 yields the
    # token index of each selected (expert, token) pair in expert-major order.
    gather_indices = torch.nonzero(expert_major_mask)[:, 1]
    # Gather the (possibly duplicated) tokens into their permuted positions.
    permuted_tokens = tokens.index_select(0, gather_indices)
    return permuted_tokens, gather_indices
def pytorch_unpermute_mask_map(
    permuted_tokens: torch.Tensor,
    sorted_indices: torch.Tensor,
    restore_shape: torch.Size,
    probs: torch.Tensor = None,
    routing_map: torch.Tensor = None,
):
    """Invert the mask-based permutation, restoring the original token order.

    Rows duplicated by the permutation (tokens routed to several experts) are
    summed back into their source row. When probs are provided, each permuted
    token is first scaled by its routing probability.

    Args:
        permuted_tokens (torch.Tensor): Tokens in expert-major permuted order.
        sorted_indices (torch.Tensor): Original row index of each permuted token.
        restore_shape (torch.Size): The shape of the unpermuted tensor.
        probs (torch.Tensor, optional): The unpermuted probs tensor,
            [num_tokens, num_experts].
        routing_map (torch.Tensor, optional): Token to expert mapping, shape
            [num_tokens, num_experts]; required when probs are given.

    Returns:
        torch.Tensor: The tokens restored to their original order.
    """
    if probs is not None:
        assert routing_map is not None, "Mask must be provided to permute the probs."
        # Select the probs in the same expert-major order as the permuted tokens.
        expert_major_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
        permuted_tokens = permuted_tokens * expert_major_probs.unsqueeze(-1)
    restored = torch.zeros(
        restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype
    )
    # Accumulate every permuted row back onto its source row (sums duplicates).
    restored.index_add_(0, sorted_indices, permuted_tokens)
    return restored
def pytorch_sort_chunks_by_index(
    input: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_idxs: torch.Tensor,
):
    """
    Split the input tensor into chunks along dim 0 and concatenate them
    in the order given by ``sorted_idxs``.

    Args:
        input (torch.Tensor): The input tensor, [num_tokens, hidden].
        split_sizes (torch.Tensor): Sizes of the consecutive chunks along dim 0.
        sorted_idxs (torch.Tensor): Order in which to concatenate the chunks.

    Returns:
        torch.Tensor: The chunk-sorted tensor. Note: unlike the fused TE op,
        this reference implementation returns only the output tensor (no
        row_id_map).
    """
    # Keep the parameter binding intact; split into a tuple of views.
    chunks = torch.split(input, split_sizes.tolist(), dim=0)
    output = torch.cat([chunks[i] for i in sorted_idxs.tolist()], dim=0)
    return output
def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]:
"""Estimated tolerances for a datatype
......@@ -112,7 +198,7 @@ def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]:
raise ValueError(f"Unsuppored dtype ({te_dtype})")
def _test_permutation(
def _test_permutation_index_map(
te_dtype,
num_tokens,
num_expert,
......@@ -132,7 +218,8 @@ def _test_permutation(
num_out_tokens = num_tokens * topK
print(
f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
"index map:"
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
)
fp8 = False
......@@ -198,7 +285,7 @@ def _test_permutation(
# PyTorch Permutation
#
###################################################################################################################################
pytorch_permute_output, sorted_indices = pytorch_permute(
pytorch_permute_output, sorted_indices = pytorch_permute_index_map(
pytorch_permute_fwd_input, indices, num_out_tokens
)
pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True)
......@@ -206,9 +293,264 @@ def _test_permutation(
pytorch_unpermute_fwd_input = pytorch_permute_output.detach()
pytorch_unpermute_fwd_input.requires_grad_(True)
pytorch_unpermute_output = pytorch_unpermute(
pytorch_unpermute_output = pytorch_unpermute_index_map(
pytorch_unpermute_fwd_input, sorted_indices, probs=probs
)
pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
#
# TE Permutation
#
###################################################################################################################################
te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
te_permute_fwd_input.requires_grad_(True)
te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()
te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, indices, num_out_tokens, map_type="index"
)
te_permute_output.backward(te_permute_bwd_input, retain_graph=True)
te_probs = None
if with_probs:
te_probs = probs.detach()
te_probs.requires_grad_(True)
te_unpermute_fwd_input = te_permute_output.detach()
te_unpermute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
te_unpermute_output = te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"
)
te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
#
# Results Check
#
###################################################################################################################################
tols = dtype_tols(te_dtype)
if fp8:
te_permute_output_ = te_permute_output.from_float8(torch.float32)
te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32)
te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32)
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32)
else:
te_permute_output_ = te_permute_output.float()
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
torch.testing.assert_close(
pytorch_permute_output.float(),
te_permute_output_,
msg=f"Mismatch in te_permute fwd",
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_fwd_input.grad.float(),
te_unpermute_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if with_probs:
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols
)
if not pytorch_permute_fwd_input.numel():
print("Empty pytorch_permute_fwd_input activation test passed.")
return
###################################################################################################################################
#
# Benchmark
#
###################################################################################################################################
def backward_wrapper(
act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False
):
# Set forward_input.grad to None to avoid grad accumulation.
if accumulate_grad == False:
for i in forward_input:
i.grad = None
return act.backward(backward_input, retain_graph=retain_graph)
if BENCHMARK:
t1 = perf_test_cuda_kernel(
lambda: pytorch_permute_index_map(pytorch_permute_fwd_input, indices, num_out_tokens)
)
t2 = perf_test_cuda_kernel(
lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens, map_type="index")
)
print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
pytorch_permute_output,
pytorch_permute_bwd_input,
forward_input=[pytorch_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_permute_output,
te_permute_bwd_input,
forward_input=[te_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: pytorch_unpermute_index_map(
pytorch_unpermute_fwd_input, sorted_indices, probs=probs
)
)
t2 = perf_test_cuda_kernel(
lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs, map_type="index")
)
print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: backward_wrapper(
pytorch_unpermute_output,
pytorch_unpermute_bwd_input,
forward_input=(
[pytorch_unpermute_fwd_input, probs]
if with_probs
else [pytorch_unpermute_fwd_input]
),
retain_graph=True,
accumulate_grad=False,
)
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_unpermute_output,
te_unpermute_bwd_input,
forward_input=(
[te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input]
),
retain_graph=True,
accumulate_grad=False,
)
)
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
def _test_permutation_mask_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
with_probs,
BENCHMARK=False,
):
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
if num_out_tokens == None:
num_out_tokens = num_tokens * topK
print(
"mask map:"
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
)
fp8 = False
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
elif te_dtype == tex.DType.kFloat16:
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else:
pytest.skip("Invalid dtype.")
if fp8:
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
permute_bwd_input = torch.rand(
size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
unpermute_bwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
permute_fwd_input = Float8Tensor.to_float8(
permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)
)
permute_bwd_input = Float8Tensor.to_float8(
permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)
)
unpermute_bwd_input = Float8Tensor.to_float8(
unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)
)
pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16)
pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16)
pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16)
else:
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_fwd_input.requires_grad_(True)
restore_shape = pytorch_permute_fwd_input.shape
_tmp_tensor = torch.zeros((num_tokens * num_expert,))
_tmp_tensor[: int(num_out_tokens)] = 1.0
_tmp_idx = torch.randperm(num_tokens * num_expert)
routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()
probs = None
if with_probs:
probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums
probs.requires_grad_(True)
###################################################################################################################################
#
# PyTorch Permutation
#
###################################################################################################################################
pytorch_permute_output, sorted_indices = pytorch_permute_mask_map(
pytorch_permute_fwd_input, routing_map
)
pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True)
pytorch_unpermute_fwd_input = pytorch_permute_output.detach()
pytorch_unpermute_fwd_input.requires_grad_(True)
pytorch_unpermute_output = pytorch_unpermute_mask_map(
pytorch_unpermute_fwd_input, sorted_indices, restore_shape, probs, routing_map
)
pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
......@@ -220,7 +562,9 @@ def _test_permutation(
te_permute_fwd_input.requires_grad_(True)
te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()
te_permute_output, row_id_map = te_permute(te_permute_fwd_input, indices, num_out_tokens)
te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, routing_map, num_out_tokens, map_type="mask"
)
te_permute_output.backward(te_permute_bwd_input, retain_graph=True)
te_probs = None
......@@ -231,7 +575,9 @@ def _test_permutation(
te_unpermute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
te_unpermute_output = te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
te_unpermute_output = te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask"
)
te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)
###################################################################################################################################
......@@ -256,6 +602,7 @@ def _test_permutation(
pytorch_permute_output.float(),
te_permute_output_,
msg=f"Mismatch in te_permute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
......@@ -300,10 +647,10 @@ def _test_permutation(
if BENCHMARK:
t1 = perf_test_cuda_kernel(
lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens)
lambda: pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map)
)
t2 = perf_test_cuda_kernel(
lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens)
lambda: te_permute(te_permute_fwd_input, routing_map, num_out_tokens, map_type="mask")
)
print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
......@@ -328,10 +675,14 @@ def _test_permutation(
print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
t1 = perf_test_cuda_kernel(
lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs)
lambda: pytorch_unpermute_mask_map(
pytorch_unpermute_fwd_input, sorted_indices, restore_shape, probs, routing_map
)
)
t2 = perf_test_cuda_kernel(
lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
lambda: te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask"
)
)
print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
......@@ -362,6 +713,158 @@ def _test_permutation(
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
def _test_moe_chunk_sort(
    te_dtype,
    num_tokens,
    num_expert,
    tp_size,
    hidden_size,
    BENCHMARK=False,
):
    """Compare TE's fused chunk sort against the PyTorch reference.

    Randomly partitions ``num_tokens`` rows into ``num_expert * tp_size``
    chunks, sorts the chunks with both ``pytorch_sort_chunks_by_index`` and
    ``te_sort_chunks_by_index``, and checks that forward outputs and backward
    input grads match within dtype tolerances. Optionally benchmarks both
    implementations.
    """
    print(
        "chunk permute:"
        f" token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert} tp_size:{tp_size} {te_dtype}"
    )

    fp8 = False
    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
    elif te_dtype == tex.DType.kFloat16:
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
    elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
        dtype = torch.uint8
        fp8 = True
    else:
        pytest.skip("Invalid dtype.")

    if fp8:
        # FP8 path: the TE op consumes Float8Tensor; the PyTorch reference
        # consumes the same values dequantized to fp16 (scale fixed to 1.0).
        fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
        bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
        fwd_input = Float8Tensor.to_float8(
            fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)
        )
        bwd_input = Float8Tensor.to_float8(
            bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0)
        )
        pytorch_fwd_input = fwd_input.from_float8(torch.float16)
        pytorch_bwd_input = bwd_input.from_float8(torch.float16)
    else:
        pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
        pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
    pytorch_fwd_input.requires_grad_(True)

    # Randomly assign each token to one of the num_expert * tp_size chunks.
    _split_sizes = [0] * (num_expert * tp_size)
    for _ in range(num_tokens):
        idx = random.randint(0, num_expert * tp_size - 1)
        _split_sizes[idx] += 1
    split_sizes = torch.tensor(_split_sizes, dtype=torch.int32).ravel()
    split_sizes_cuda = split_sizes.to(device="cuda")
    # reshape(tp_size, num_expert).T.ravel() regroups chunk ids from
    # TP-major to expert-major order: [e0@tp0, e0@tp1, ..., e1@tp0, ...].
    _sorted_idxs = torch.arange(num_expert * tp_size, dtype=torch.int32)
    sorted_idxs = _sorted_idxs.reshape(tp_size, num_expert).T.ravel()
    sorted_idxs_cuda = sorted_idxs.to(device="cuda")

    ###################################################################################################################################
    #
    # PyTorch Permutation
    #
    ###################################################################################################################################
    pytorch_output = pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs)
    pytorch_output.backward(pytorch_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # TE Permutation
    #
    ###################################################################################################################################
    te_fwd_input = fwd_input if fp8 else pytorch_fwd_input.detach()
    te_fwd_input.requires_grad_(True)
    te_bwd_input = bwd_input if fp8 else pytorch_bwd_input.detach()

    te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda)
    te_output.backward(te_bwd_input, retain_graph=True)

    ###################################################################################################################################
    #
    # Results Check
    #
    ###################################################################################################################################
    tols = dtype_tols(te_dtype)

    # Compare in float32 regardless of the working dtype.
    if fp8:
        te_output_ = te_output.from_float8(torch.float32)
        te_fwd_input_grad = te_fwd_input.grad.from_float8(torch.float32)
    else:
        te_output_ = te_output.float()
        te_fwd_input_grad = te_fwd_input.grad.float()

    torch.testing.assert_close(
        pytorch_output.float(),
        te_output_,
        msg=f"Mismatch in te_permute fwd",
        **tols,
    )
    torch.testing.assert_close(
        pytorch_fwd_input.grad.float(),
        te_fwd_input_grad,
        msg=f"Mismatch in te_permute bwd",
        **tols,
    )

    if not pytorch_fwd_input.numel():
        print("Empty pytorch_fwd_input activation test passed.")
        return

    ###################################################################################################################################
    #
    # Benchmark
    #
    ###################################################################################################################################
    def backward_wrapper(
        act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False
    ):
        # Set forward_input.grad to None to avoid grad accumulation.
        if accumulate_grad == False:
            for i in forward_input:
                i.grad = None
        return act.backward(backward_input, retain_graph=retain_graph)

    if BENCHMARK:
        t1 = perf_test_cuda_kernel(
            lambda: pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs)
        )
        t2 = perf_test_cuda_kernel(
            lambda: te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda)
        )
        print(f"chunk sort\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")

        t1 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                pytorch_output,
                pytorch_bwd_input,
                forward_input=[pytorch_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        t2 = perf_test_cuda_kernel(
            lambda: backward_wrapper(
                te_output,
                te_bwd_input,
                forward_input=[te_fwd_input],
                retain_graph=True,
                accumulate_grad=False,
            )
        )
        print(f"chunk sort\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
def perf_test_cuda_kernel(cuda_kernel_fn):
if torch.cuda.is_available():
# create CUDA event
......@@ -396,7 +899,7 @@ if is_bf16_compatible():
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation(
def test_permutation_index_map(
te_dtype,
num_tokens,
num_expert,
......@@ -407,7 +910,36 @@ def test_permutation(
with_probs = True
BENCHMARK = False
_test_permutation(
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_mask_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
with_probs = True
BENCHMARK = False
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
......@@ -430,7 +962,37 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_fp8(
def test_permutation_index_map_fp8(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
with_probs = True
BENCHMARK = False
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_mask_map_fp8(
te_dtype,
num_tokens,
num_expert,
......@@ -441,7 +1003,7 @@ def test_permutation_fp8(
with_probs = True
BENCHMARK = False
_test_permutation(
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
......@@ -457,7 +1019,7 @@ def test_permutation_fp8(
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_topk1_no_probs(
def test_permutation_index_map_topk1_no_probs(
te_dtype,
num_tokens,
num_expert,
......@@ -468,7 +1030,7 @@ def test_permutation_topk1_no_probs(
with_probs = False
BENCHMARK = False
_test_permutation(
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
......@@ -480,6 +1042,57 @@ def test_permutation_topk1_no_probs(
)
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_mask_map_topk1_no_probs(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
):
    """Mask-map permutation with a single expert per token (topK=1), no probs."""
    _test_permutation_mask_map(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=1,
        num_out_tokens=None,
        with_probs=False,
        BENCHMARK=False,
    )
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permutation(
    te_dtype,
    num_tokens,
    num_expert,
    tp_size,
    hidden_size,
):
    """Functional comparison of chunk sorting (benchmarking disabled)."""
    _test_moe_chunk_sort(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        tp_size=tp_size,
        hidden_size=hidden_size,
        BENCHMARK=False,
    )
def test_permutation_single_case():
print("GPU:", torch.cuda.get_device_name(0))
......@@ -497,7 +1110,18 @@ def test_permutation_single_case():
with_probs = True
Benchmark = True
_test_permutation(
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=Benchmark,
)
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
......@@ -508,6 +1132,15 @@ def test_permutation_single_case():
BENCHMARK=Benchmark,
)
_test_moe_chunk_sort(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
tp_size=4,
hidden_size=hidden_size,
BENCHMARK=Benchmark,
)
if __name__ == "__main__":
test_permutation_single_case()
......@@ -74,7 +74,11 @@ from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute
from transformer_engine.pytorch.permutation import (
moe_permute,
moe_unpermute,
moe_sort_chunks_by_index,
)
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformer_engine.pytorch.graph import make_graphed_callables
......
......@@ -2,24 +2,26 @@
#
# See LICENSE for license information.
"""Linear API"""
"""MoE Permutation API"""
import warnings
from typing import Tuple
import torch
import transformer_engine_torch as tex
from .constants import TE_DType
from .float8_tensor import Float8Tensor
import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.float8_tensor import Float8Tensor
__all__ = [
"moe_permute",
"moe_unpermute",
"moe_sort_chunks_by_index",
]
class _moe_permute(torch.autograd.Function):
"""functional Permute"""
class _moe_permute_index_map(torch.autograd.Function):
"""functional Permute with index router map"""
workspace = None
max_expanded_token_num = 0
......@@ -28,7 +30,7 @@ class _moe_permute(torch.autograd.Function):
def forward(
ctx,
inp: torch.Tensor,
indices: torch.Tensor,
index: torch.Tensor,
num_out_tokens: int,
max_token_num: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
......@@ -39,9 +41,9 @@ class _moe_permute(torch.autograd.Function):
# Device check
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert indices.is_cuda, "TransformerEngine needs CUDA."
assert index.is_cuda, "TransformerEngine needs CUDA."
# Shape check
assert inp.size(0) == indices.size(0), "Permute not possible"
assert inp.size(0) == index.size(0), "Permute not possible"
# Data type check
fp8 = isinstance(inp, Float8Tensor)
......@@ -51,27 +53,27 @@ class _moe_permute(torch.autograd.Function):
inp = inp._data
else:
dtype = TE_DType[inp.dtype]
if indices.dtype != torch.int32:
if index.dtype != torch.int32:
warnings.warn(
f"The data type of the input `indices` of Permute is {indices.dtype}! "
f"The data type of the input `index` of Permute is {index.dtype}! "
"The recommended type is torch.int32."
)
indices = indices.to(torch.int32)
index = index.to(torch.int32)
topK = indices.size(1)
topK = index.size(1)
input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
if _moe_permute.max_expanded_token_num < input_max_expanded_token_num:
_moe_permute.max_expanded_token_num = input_max_expanded_token_num
_moe_permute.workspace = []
if _moe_permute_index_map.max_expanded_token_num < input_max_expanded_token_num:
_moe_permute_index_map.max_expanded_token_num = input_max_expanded_token_num
_moe_permute_index_map.workspace = []
permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd(
permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd(
inp,
dtype,
indices,
index,
num_out_tokens,
_moe_permute.workspace,
_moe_permute.max_expanded_token_num,
_moe_permute_index_map.workspace,
_moe_permute_index_map.max_expanded_token_num,
)
if fp8:
......@@ -80,8 +82,8 @@ class _moe_permute(torch.autograd.Function):
)
ctx.row_id_map = row_id_map
ctx.num_tokens = indices.size(0)
ctx.topK = indices.size(1)
ctx.num_tokens = index.size(0)
ctx.topK = index.size(1)
ctx.fp8 = fp8
return permuted_act, row_id_map
......@@ -122,8 +124,8 @@ class _moe_permute(torch.autograd.Function):
return act_grad, None, None, None
class _moe_unpermute(torch.autograd.Function):
"""functional Unpermute"""
class _moe_unpermute_index_map(torch.autograd.Function):
"""functional Unpermute with index router map"""
@staticmethod
def forward(
......@@ -225,21 +227,238 @@ class _moe_unpermute(torch.autograd.Function):
return act_grad, None, prob_grad
class _moe_permute_mask_map(torch.autograd.Function):
    """functional Permute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        routing_map: torch.Tensor,
        num_out_tokens: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pylint: disable=missing-function-docstring
        # Empty input: nothing to permute; return an empty row-id map as well.
        if not inp.numel():
            return inp, torch.tensor([], device=inp.device)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert routing_map.is_cuda, "TransformerEngine needs CUDA."
        # Shape check: one routing-map row per input token.
        assert inp.size(0) == routing_map.size(0), "Permute not possible"
        num_tokens, hidden_size = inp.size()
        num_experts = routing_map.size(1)
        assert (
            num_out_tokens is not None
        ), "num_out_tokens must be provided to the fused permute function."

        # Build the row-id map once in forward; it is saved for backward and
        # also returned so callers can reuse it (e.g. for unpermute).
        row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)

        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            # Permute the raw uint8 payload and re-wrap with the same scale below.
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        output = triton_permutation.permute_with_mask_map(
            inp,
            row_id_map,
            num_tokens,
            num_experts,
            num_out_tokens,
            hidden_size,
        )
        if fp8:
            output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv)

        ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output, row_id_map

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
        _,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        # Mirror of the forward's empty-input fast path.
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None

        act_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                permuted_act_grad = permuted_act_grad._data
            else:
                fp8_dtype = None
            # Scatter the incoming grads back to the original token order.
            # NOTE(review): the third argument is presumably the probs slot
            # (None => unweighted accumulation) — confirm against the Triton
            # kernel's signature.
            act_grad = triton_permutation.unpermute_with_mask_map(
                permuted_act_grad,
                row_id_map,
                None,
                ctx.num_tokens,
                ctx.num_experts,
                ctx.hidden_size,
                fp8_dtype,
            )
            if fp8:
                # NOTE(review): scale_inv is enlarged by num_experts, presumably
                # to leave headroom for the up-to-num_experts-way accumulation
                # per token during unpermute — confirm.
                act_grad = Float8Tensor(
                    data=act_grad,
                    fp8_dtype=fp8_dtype,
                    fp8_scale_inv=fp8_scale_inv * ctx.num_experts,
                )
        # One grad per forward input: (inp, routing_map, num_out_tokens).
        return act_grad, None, None
class _moe_unpermute_mask_map(torch.autograd.Function):
    """functional Unpermute with mask router map"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        row_id_map: torch.Tensor,
        probs: torch.Tensor,
        restore_shape: torch.Size,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Merge permuted tokens back to their original rows.
        #   inp:           [num_permuted_tokens, hidden_size] permuted activations.
        #   row_id_map:    [num_experts, num_tokens] map produced by the permute op.
        #   probs:         optional [num_tokens, num_experts] weights; when given,
        #                  each expert copy is scaled by its prob before summation.
        #   restore_shape: output shape; defaults to inp.shape when None.
        if not inp.numel():
            # Keep probs on the ctx so backward can return it unchanged for the
            # empty-input fast path.
            ctx.probs = probs
            return inp

        if restore_shape is None:
            restore_shape = inp.shape
        num_tokens, hidden_size = restore_shape
        num_experts = row_id_map.size(0)

        with_probs = probs is not None
        if with_probs:
            assert probs.is_cuda, "TransformerEngine needs CUDA."
            if probs.dtype != torch.float32:
                warnings.warn(
                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
                    "The recommended type is torch.float32."
                )
                probs = probs.to(torch.float32)

        # Device check
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            if not with_probs:
                # Without probs the FP8 kernel divides the accumulated sum by
                # num_experts to avoid overflow; compensate in the inverse scale.
                fp8_scale_inv = inp._scale_inv * num_experts
            else:
                fp8_scale_inv = inp._scale_inv
            inp = inp._data
        else:
            fp8_dtype = None
        unpermuted_output = triton_permutation.unpermute_with_mask_map(
            inp,
            row_id_map,
            probs,
            num_tokens,
            num_experts,
            hidden_size,
            fp8_dtype=fp8_dtype,
        )
        if fp8:
            unpermuted_output = Float8Tensor(
                data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
            )

        # The forward input (raw data for FP8) is only needed by backward when
        # probs gradients must be computed.
        if with_probs:
            ctx.save_for_backward(inp, row_id_map, probs)
        else:
            ctx.save_for_backward(row_id_map)
        ctx.num_experts = num_experts
        ctx.num_tokens = num_tokens
        ctx.num_permuted_tokens = inp.size(0)
        ctx.hidden_size = hidden_size
        ctx.with_probs = with_probs
        return unpermuted_output

    @staticmethod
    def backward(ctx, unpermuted_act_grad):
        # pylint: disable=missing-function-docstring
        if not unpermuted_act_grad.numel():
            return unpermuted_act_grad, None, ctx.probs, None

        act_grad = None
        probs_grad = None
        if ctx.needs_input_grad[0]:
            if ctx.with_probs:
                fwd_input, row_id_map, probs = ctx.saved_tensors
            else:
                (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(unpermuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = unpermuted_act_grad._fp8_dtype
                fp8_scale_inv = unpermuted_act_grad._scale_inv
                unpermuted_act_grad = unpermuted_act_grad._data
            else:
                fp8_dtype = None
            if ctx.with_probs:
                # Fused kernel: scatters prob-scaled grads back to the permuted
                # layout and reduces fwd_input * grad into probs_grad.
                act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_probs(
                    unpermuted_act_grad,
                    row_id_map,
                    fwd_input,
                    probs,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                    fp8_dtype,
                )
            else:
                # Without probs, the gradient of an accumulate-unpermute is a
                # plain permute (replicate each token's grad to its expert rows).
                act_grad = triton_permutation.permute_with_mask_map(
                    unpermuted_act_grad,
                    row_id_map,
                    ctx.num_tokens,
                    ctx.num_experts,
                    ctx.num_permuted_tokens,
                    ctx.hidden_size,
                )
            if fp8:
                act_grad = Float8Tensor(
                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
                )
        if not ctx.needs_input_grad[2]:
            probs_grad = None
        return act_grad, None, probs_grad, None
def moe_permute(
    inp: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
    map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Permute the tokens based on the routing_map.
    Tokens with the same designated expert will be grouped together.
    The routing_map indicates which experts were selected by each token.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    routing_map: torch.Tensor
        The token to expert mapping tensor.
        If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
        The values in it: 1 means the token is routed to this expert and 0 means not.
        If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
        The values in it are the routed expert indices.
    num_out_tokens: int, default = -1
        The effective output token count, representing the number of tokens not dropped.
        By default, set to '-1', meaning no tokens are dropped.
    max_token_num: int, default = -1
        The maximum number of tokens, used for workspace allocation.
        By default, set to '-1', meaning the calculation of the size of workspace is
        automatically taken over by the operator. Only used when map_type is 'index'.
    map_type: str, default = 'mask'
        Type of the routing map tensor.
        Options are: 'mask', 'index'.

    Returns
    -------
    A tuple of the permuted tensor and the row_id_map needed by `moe_unpermute`.

    Raises
    ------
    ValueError
        If `map_type` is neither 'mask' nor 'index'.
    """
    # Dispatch to the implementation matching the routing-map representation.
    if map_type == "index":
        return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
    if map_type == "mask":
        return _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens)
    raise ValueError("map_type should be one of 'mask' or 'index'")
def moe_unpermute(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: torch.Tensor = None,
    restore_shape: torch.Tensor = None,
    map_type: str = "mask",
) -> torch.Tensor:
    """
    Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
    corresponding probabilities.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_permuted_tokens, hidden_size]` with permuted tokens,
        as produced by `moe_permute`.
    row_id_map: torch.Tensor
        The token-to-row mapping tensor returned by `moe_permute`; it must come from a
        permute call with the same `map_type`.
    probs: torch.Tensor, default = None
        The tensor of probabilities corresponding to the permuted tokens. If provided,
        the unpermuted tokens will be merged with their respective probabilities.
        By default, set to an empty tensor, which means that the tokens are directly
        merged by accumulation.
    restore_shape: torch.Tensor
        The output shape after the unpermute operation. Only used when map_type is
        'mask'; defaults to the input shape when not given.
    map_type: str, default = 'mask'
        Type of the routing map tensor. Should be the same as the value passed to moe_permute.
        Options are: 'mask', 'index'.

    Raises
    ------
    ValueError
        If `map_type` is neither 'mask' nor 'index'.
    """
    # Dispatch to the implementation matching the routing-map representation.
    if map_type == "index":
        return _moe_unpermute_index_map.apply(inp, row_id_map, probs)
    if map_type == "mask":
        return _moe_unpermute_mask_map.apply(inp, row_id_map, probs, restore_shape)
    raise ValueError("map_type should be one of 'mask' or 'index'")
class _moe_chunk_sort(torch.autograd.Function):
    """functional MoE chunk permute"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        split_sizes: torch.Tensor,
        sorted_idxs: torch.Tensor,
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        # Split `inp` along dim 0 into chunks of sizes `split_sizes` and reorder
        # the chunks according to `sorted_idxs`. Returns the reordered tensor.
        if not inp.numel():
            # Bug fix: the empty-input fast path used to return a
            # (tensor, tensor) tuple while the normal path returns a single
            # tensor. Two forward outputs would make autograd pass two grads to
            # a backward() that accepts one, and callers would see an
            # inconsistent return arity. Return the (empty) input alone.
            return inp

        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert split_sizes.is_cuda, "TransformerEngine needs CUDA."
        assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA."
        num_tokens, hidden_size = inp.shape
        num_splits = split_sizes.size(0)
        assert num_splits == sorted_idxs.size(0)

        # FP8: sort the raw uint8 payload and re-wrap; row moves keep the scale.
        fp8 = isinstance(inp, Float8Tensor)
        if fp8:
            fp8_dtype = inp._fp8_dtype
            fp8_scale_inv = inp._scale_inv
            inp = inp._data
        output, row_id_map = triton_permutation.sort_chunks_by_idx(
            inp,
            split_sizes,
            sorted_idxs,
            num_tokens,
            hidden_size,
            num_splits,
        )
        if fp8:
            output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv)

        # row_id_map[src_row] = dst_row; backward inverts the move with a gather.
        ctx.save_for_backward(row_id_map)
        ctx.num_tokens = num_tokens
        ctx.hidden_size = hidden_size
        return output

    @staticmethod
    def backward(
        ctx,
        permuted_act_grad: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        # pylint: disable=missing-function-docstring
        # The gradient of a row reordering is the inverse reordering.
        if not permuted_act_grad.numel():
            return permuted_act_grad, None, None

        act_grad = None
        if ctx.needs_input_grad[0]:
            (row_id_map,) = ctx.saved_tensors
            fp8 = isinstance(permuted_act_grad, Float8Tensor)
            if fp8:
                fp8_dtype = permuted_act_grad._fp8_dtype
                fp8_scale_inv = permuted_act_grad._scale_inv
                permuted_act_grad = permuted_act_grad._data
            act_grad = triton_permutation.sort_chunks_by_map(
                permuted_act_grad,
                row_id_map,
                ctx.num_tokens,
                ctx.hidden_size,
            )
            if fp8:
                act_grad = Float8Tensor(
                    data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
                )
        return act_grad, None, None
def moe_sort_chunks_by_index(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_index: torch.Tensor,
) -> torch.Tensor:
    """
    Split and sort the input tensor based on the split_sizes and sorted indices.
    The inp tensor is split along dim-0 according to the split_sizes list and the
    resulting chunks are then rearranged according to sorted_index.

    Parameters
    ----------
    inp: torch.Tensor
        Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
    split_sizes: torch.Tensor
        Chunk sizes of the inp tensor along the 0-th dimension.
    sorted_index: torch.Tensor
        Chunk indices used to permute the chunks.

    Returns
    -------
    The tensor with its chunks reordered; same shape as `inp`.
    """
    # Note: removed a stale, unreachable `return _moe_unpermute.apply(...)`
    # left over from a previous revision; `_moe_chunk_sort` returns a single
    # tensor, so the return annotation is `torch.Tensor`, not a tuple.
    return _moe_chunk_sort.apply(inp, split_sizes, sorted_index)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Kernels written with OpenAI Triton."""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Permutation kernels written with OpenAI Triton."""
from typing import Union
import torch
import triton
import triton.language as tl
from transformer_engine_torch import DType as TE_DType
@triton.jit
def _row_id_map_pass_1_kernel(
    # pointers
    routing_map_ptr,  # [num_tokens, num_experts] token->expert mask
    row_id_map_ptr,  # [num_experts, num_tokens] output map (block-local ids after this pass)
    workspace_ptr,  # [num_experts, cdiv(num_tokens, BLOCK_SIZE)] per-block routed-token counts
    # sizes
    num_tokens,
    # strides
    stride_routing_map_token,
    stride_routing_map_expert,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 1 of row-id-map construction: one program per (expert, token-block).
    # Computes, within each block, the 1-based running count of tokens routed to
    # this expert, and records the block's total so pass 2 can turn the local
    # ranks into global destination rows.
    pid_m = tl.program_id(0)  # expert index
    pid_n = tl.program_id(1)  # token-block index
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    expert_token_mask = tl.load(
        routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
        mask=(offset < num_tokens),
        other=0,
    ).to(tl.int64)
    # Inclusive cumsum gives each routed token a 1-based rank within the block;
    # multiplying by the mask zeroes the entries of unrouted tokens.
    row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
    tl.store(
        row_id_map_ptr + pid_m * num_tokens + offset,
        row_id_within_token_block,
        mask=offset < num_tokens,
    )
    # Total routed tokens in this (expert, block) cell, consumed by pass 2.
    n_tokens_per_block = tl.sum(expert_token_mask)
    tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block)
@triton.jit
def _row_id_map_pass_2_kernel(
    # pointers
    row_id_map_ptr,  # in/out: block-local ranks in, global destination rows out
    workspace_ptr,  # flattened per-(expert, block) routed-token counts from pass 1
    # sizes
    num_tokens,
    # metas
    WORKSPACE_LOAD_WIDTH: tl.constexpr,  # power-of-2 >= total number of (expert, block) cells
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 2: convert the block-local 1-based ranks from pass 1 into global
    # destination rows by adding the exclusive prefix sum of all preceding
    # (expert, block) counts, experts outermost. Unrouted entries (rank 0)
    # become -1 so downstream kernels can skip them.
    pid_m = tl.program_id(0)  # expert index
    pid_n = tl.program_id(1)  # token-block index
    chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_id_within_token_block = tl.load(
        row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0
    )
    workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
    # Load the counts of every cell strictly before this one.
    # NOTE(review): no `other=` on this masked load — Triton leaves masked-off
    # lanes undefined; this relies on them summing as zero. Consider `other=0`.
    n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx)
    row_id = tl.where(
        row_id_within_token_block == 0,
        -1,
        # rank is 1-based, hence the trailing -1 to get a 0-based row index.
        row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
    )
    tl.store(
        row_id_map_ptr + pid_m * num_tokens + offset,
        row_id,
        mask=(offset < num_tokens),
    )
def make_row_id_map(
    routing_map: torch.Tensor,
    num_tokens: int,
    num_experts: int,
):
    """Build the destination-row map from a token->expert routing mask.

    Returns an int64 tensor of shape [num_experts, num_tokens] whose entry
    [e, t] is the row of token t in the expert-grouped (permuted) layout,
    or -1 when token t is not routed to expert e.
    """
    BLOCK = 256
    n_blocks = triton.cdiv(num_tokens, BLOCK)
    launch_grid = (num_experts, n_blocks)
    row_id_map = torch.empty((num_experts, num_tokens), dtype=torch.int64, device="cuda")
    block_counts = torch.empty(launch_grid, dtype=torch.int64, device="cuda")
    # Pass 1: per-block cumulative ranks of the routing mask + per-block totals.
    _row_id_map_pass_1_kernel[launch_grid](
        routing_map,
        row_id_map,
        block_counts,
        num_tokens,
        routing_map.stride(0),
        routing_map.stride(1),
        BLOCK,
    )
    # Pass 2: add global block offsets and flag unrouted entries with -1.
    _row_id_map_pass_2_kernel[launch_grid](
        row_id_map,
        block_counts,
        num_tokens,
        triton.next_power_of_2(num_experts * n_blocks),
        BLOCK,
    )
    return row_id_map
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _permute_kernel(
    # pointers
    input_ptr,  # [num_tokens, hidden_size] source activations
    output_ptr,  # [num_out_tokens, hidden_size] permuted destination
    row_id_map_ptr,  # [num_experts, num_tokens]; -1 marks "not routed"
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    # One program per source token. Walks the hidden dimension in BLOCK_SIZE
    # tiles; each tile is loaded once and stored to every expert's destination
    # row for this token (a token routed to k experts is copied k times).
    pid = tl.program_id(0)  # source-token index
    cur_pos = 0
    while cur_pos < hidden_size:
        cur_off = cur_pos + tl.arange(0, BLOCK_SIZE)
        mask = cur_off < hidden_size
        input_off = pid * stride_input_token + cur_off * stride_input_hidden
        inp = tl.load(input_ptr + input_off, mask=mask)
        for expert_idx in range(num_experts):
            dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
            if dst_row != -1:
                output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
                tl.store(output_ptr + output_off, inp, mask=mask)
        cur_pos += BLOCK_SIZE
def permute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
):
    """Copy each token's row to its expert-grouped destination row(s).

    `row_id_map` comes from `make_row_id_map`; the result has
    `num_out_tokens` rows of the same dtype as `inp`.
    """
    permuted = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    launch_grid = (num_tokens,)  # one program per source token
    _permute_kernel[launch_grid](
        inp,
        permuted,
        row_id_map,
        num_tokens,
        num_experts,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        permuted.stride(0),
        permuted.stride(1),
    )
    return permuted
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _unpermute_kernel(
    # pointers
    input_ptr,  # [num_permuted_tokens, hidden_size] permuted activations
    output_ptr,  # [num_tokens, hidden_size] merged result
    row_id_map_ptr,  # [num_experts, num_tokens]; -1 marks "not routed"
    probs_ptr,  # [num_tokens, num_experts] merge weights (only if WITH_PROBS)
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_probs_expert,
    # metas
    WITH_PROBS: tl.constexpr,
    FP8_DTYPE: tl.constexpr,  # "e5m2" / "e4m3" / None; FP8 rows arrive as uint8
    BLOCK_SIZE: tl.constexpr,
):
    # One program per destination token: gathers the token's expert copies and
    # sums them (optionally weighted by probs). FP8 inputs are bitcast from
    # uint8 storage, accumulated in fp16, and bitcast back on store.
    if FP8_DTYPE == "e5m2":
        compute_type = tl.float16
        data_type = tl.float8e5
        pytorch_tensor_dtype = tl.uint8
    elif FP8_DTYPE == "e4m3":
        compute_type = tl.float16
        data_type = tl.float8e4nv
        pytorch_tensor_dtype = tl.uint8
    else:
        compute_type = input_ptr.dtype.element_ty
        assert FP8_DTYPE is None
    pid = tl.program_id(0)  # destination-token index
    current_start = 0
    while current_start < hidden_size:
        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
        mask = current_offset < hidden_size
        accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
        for expert_idx in range(num_experts):
            src_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
            if src_row != -1:
                input_off = src_row * stride_input_token + current_offset * stride_input_hidden
                inp = tl.load(input_ptr + input_off, mask=mask)
                if FP8_DTYPE is not None:
                    # Reinterpret the uint8 payload as fp8, then widen to fp16.
                    inp = inp.to(data_type, bitcast=True).to(compute_type)
                if WITH_PROBS:
                    prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
                    prob = tl.load(probs_ptr + prob_off).to(compute_type)
                    inp *= prob
                accumulator += inp
        if FP8_DTYPE is not None:
            if not WITH_PROBS:
                # Directly adding these value may cause overflow for fp8, we scale it here.
                # The outside fp8_scale_inv is also scaled in the meantime.
                accumulator /= num_experts
            accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
        output_off = pid * stride_output_token + current_offset * stride_output_hidden
        tl.store(output_ptr + output_off, accumulator, mask=mask)
        current_start += BLOCK_SIZE
def unpermute_with_mask_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    probs: Union[torch.Tensor, None],
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    fp8_dtype: TE_DType,
):
    """Merge expert-grouped token copies back into token order.

    When `probs` is given each copy is weighted before summation; otherwise
    the copies are summed directly. `fp8_dtype` selects the FP8 bitcast path
    inside the kernel (None for non-FP8 inputs).
    """
    # Translate the TE enum into the kernel's constexpr tag.
    if fp8_dtype == TE_DType.kFloat8E5M2:
        dtype_tag = "e5m2"
    elif fp8_dtype == TE_DType.kFloat8E4M3:
        dtype_tag = "e4m3"
    else:
        dtype_tag = None
    has_probs = probs is not None
    if has_probs:
        probs_stride_token = probs.stride(0)
        probs_stride_expert = probs.stride(1)
    else:
        probs_stride_token = None
        probs_stride_expert = None
    merged = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    _unpermute_kernel[(num_tokens,)](
        inp,
        merged,
        row_id_map,
        probs,
        num_tokens,
        num_experts,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        merged.stride(0),
        merged.stride(1),
        probs_stride_token,
        probs_stride_expert,
        WITH_PROBS=has_probs,
        FP8_DTYPE=dtype_tag,
    )
    return merged
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _unpermute_bwd_with_probs_kernel(
    # pointers
    fwd_output_grad_ptr,  # [num_tokens, hidden_size] incoming grad of the merged output
    fwd_input_grad_ptr,  # [num_permuted_tokens, hidden_size] grad scattered to permuted layout
    fwd_input_ptr,  # [num_permuted_tokens, hidden_size] saved forward input
    probs_ptr,  # [num_tokens, num_experts] forward merge weights
    probs_grad_ptr,  # [num_tokens, num_experts] output grad w.r.t. probs
    row_id_map_ptr,  # [num_experts, num_tokens]; -1 marks "not routed"
    # sizes
    num_tokens,
    num_experts,
    hidden_size,
    # strides
    stride_fwd_output_grad_token,
    stride_fwd_output_grad_hidden,
    stride_fwd_input_grad_token,
    stride_fwd_input_grad_hidden,
    stride_fwd_input_token,
    stride_fwd_input_hidden,
    stride_probs_token,
    stride_probs_expert,
    stride_probs_grad_token,
    stride_probs_grad_expert,
    # metas
    FP8_DTYPE: tl.constexpr,  # "e5m2" / "e4m3" / None; FP8 rows arrive as uint8
    BLOCK_SIZE: tl.constexpr,
):
    # Backward of the probs-weighted unpermute, one program per token:
    #   d(permuted input) = prob * d(output)      (scattered to the expert row)
    #   d(prob)           = <fwd_input, d(output)>  (dot over the hidden dim)
    # Experts the token was not routed to get probs_grad = 0.
    if FP8_DTYPE == "e5m2":
        compute_type = tl.float16
        data_type = tl.float8e5
        pytorch_tensor_dtype = tl.uint8
    elif FP8_DTYPE == "e4m3":
        compute_type = tl.float16
        data_type = tl.float8e4nv
        pytorch_tensor_dtype = tl.uint8
    else:
        compute_type = fwd_output_grad_ptr.dtype.element_ty
        assert FP8_DTYPE is None
    pid = tl.program_id(0)  # token index
    for expert_idx in range(num_experts):
        dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
        if dst_row != -1:
            # Accumulate the prob gradient dot-product across hidden-dim tiles.
            prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
            current_start = 0
            while current_start < hidden_size:
                current_offset = current_start + tl.arange(0, BLOCK_SIZE)
                mask = current_offset < hidden_size
                input_off = (
                    pid * stride_fwd_output_grad_token
                    + current_offset * stride_fwd_output_grad_hidden
                )
                inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
                if FP8_DTYPE is not None:
                    inp = inp.to(data_type, bitcast=True).to(compute_type)
                probs_off = pid * stride_probs_token + expert_idx * stride_probs_expert
                prob = tl.load(probs_ptr + probs_off).to(compute_type)
                output = inp * prob
                if FP8_DTYPE is not None:
                    # Round back to fp8 and store via the uint8 view.
                    output = output.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
                output_off = (
                    dst_row * stride_fwd_input_grad_token
                    + current_offset * stride_fwd_input_grad_hidden
                )
                tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)
                fwd_input_off = (
                    dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
                )
                fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
                if FP8_DTYPE is not None:
                    fwd_input = fwd_input.to(data_type, bitcast=True)
                # fp32 accumulation keeps the reduction over hidden_size accurate.
                prob_grad_accum += fwd_input.to(tl.float32) * inp.to(tl.float32)
                current_start += BLOCK_SIZE
            probs_grad = tl.sum(prob_grad_accum)
            probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert
            tl.store(probs_grad_ptr + probs_grad_off, probs_grad)
        else:
            probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert
            tl.store(probs_grad_ptr + probs_grad_off, 0.0)
def unpermute_with_mask_map_bwd_with_probs(
    fwd_output_grad: torch.Tensor,
    row_id_map: torch.Tensor,
    fwd_input: torch.Tensor,
    probs: torch.Tensor,
    num_tokens: int,
    num_experts: int,
    num_out_tokens: int,
    hidden_size: int,
    fp8_dtype: TE_DType,
):
    """Backward pass of the probs-weighted unpermute.

    Returns (act_grad, probs_grad): act_grad is the gradient scattered back to
    the permuted layout ([num_out_tokens, hidden_size]), probs_grad is the
    gradient w.r.t. the merge weights ([num_tokens, num_experts]).
    """
    # Translate the TE enum into the kernel's constexpr tag.
    if fp8_dtype == TE_DType.kFloat8E5M2:
        dtype_tag = "e5m2"
    elif fp8_dtype == TE_DType.kFloat8E4M3:
        dtype_tag = "e4m3"
    else:
        dtype_tag = None
    act_grad = torch.empty(
        (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
    )
    probs_grad = torch.empty((num_tokens, num_experts), dtype=probs.dtype, device="cuda")
    _unpermute_bwd_with_probs_kernel[(num_tokens,)](
        fwd_output_grad,
        act_grad,
        fwd_input,
        probs,
        probs_grad,
        row_id_map,
        num_tokens,
        num_experts,
        hidden_size,
        fwd_output_grad.stride(0),
        fwd_output_grad.stride(1),
        act_grad.stride(0),
        act_grad.stride(1),
        fwd_input.stride(0),
        fwd_input.stride(1),
        probs.stride(0),
        probs.stride(1),
        probs_grad.stride(0),
        probs_grad.stride(1),
        dtype_tag,
    )
    return act_grad, probs_grad
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _sort_chunks_by_idxs_kernel(
    # pointers
    input_ptr,  # [num_tokens, hidden_size] rows grouped in num_splits chunks
    split_sizes_ptr,  # [num_splits] chunk lengths along dim 0
    sorted_indices_ptr,  # [num_splits] source chunk index for each output slot
    output_ptr,  # [num_tokens, hidden_size] chunk-reordered rows
    dst_rows_ptr,  # [num_tokens] out: destination row of each source row
    # sizes
    num_splits,
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    # metas
    IDX_LOAD_WIDTH: tl.constexpr,  # power-of-2 >= num_splits, for vector loads
    BLOCK_SIZE: tl.constexpr,
):
    # One program per source row: find which chunk the row belongs to, where
    # that chunk lands in the sorted order, compute the row's destination, then
    # copy the row and record the mapping for the backward pass.
    pid = tl.program_id(0)  # source-row index
    load_split_offset = tl.arange(0, IDX_LOAD_WIDTH)
    sorted_indices = tl.load(
        sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits
    )

    # get chunk idx of the current token in the input tensor
    input_chunk_idx = -1
    in_chunk_offset = tl.zeros([], dtype=tl.int64)
    acc_chunk_sizes = tl.zeros([], dtype=tl.int64)
    cursor = 0
    while cursor < num_splits:
        cur_chunk_size = tl.load(split_sizes_ptr + cursor).to(tl.int64)
        acc_chunk_sizes += cur_chunk_size
        # First chunk whose running total exceeds pid contains this row.
        if input_chunk_idx == -1 and acc_chunk_sizes > pid:
            input_chunk_idx = cursor
            in_chunk_offset = pid - (acc_chunk_sizes - cur_chunk_size)
        cursor += 1

    # get chunk idx of the current token in the output tensor
    output_chunk_idx = 0
    cursor = 0
    while cursor < num_splits:
        cur_input_idx = tl.load(sorted_indices_ptr + cursor)
        if cur_input_idx == input_chunk_idx:
            output_chunk_idx = cursor
        cursor += 1

    # make row_id_map
    # Destination row = total size of output chunks before ours + offset inside the chunk.
    output_split_sizes = tl.load(
        split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits
    ).to(tl.int64)
    output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
    dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
    tl.store(dst_rows_ptr + pid, dst_row)

    # Copy the row, tiling the hidden dimension.
    current_start = 0
    while current_start < hidden_size:
        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
        mask = current_offset < hidden_size
        input_offsets = pid * stride_input_token + current_offset * stride_input_hidden
        output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
        inp = tl.load(input_ptr + input_offsets, mask=mask)
        tl.store(output_ptr + output_offsets, inp, mask=mask)
        current_start += BLOCK_SIZE
def sort_chunks_by_idx(
    inp: torch.Tensor,
    split_sizes: torch.Tensor,
    sorted_indices: torch.Tensor,
    num_tokens: int,
    hidden_size: int,
    num_splits: int,
):
    """Reorder dim-0 chunks of `inp` according to `sorted_indices`.

    Returns (output, row_id_map) where row_id_map[src_row] is the destination
    row of each source row, enabling the inverse move in the backward pass.
    """
    dst_rows = torch.empty((num_tokens,), dtype=torch.int64, device="cuda")
    reordered = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    _sort_chunks_by_idxs_kernel[(num_tokens,)](
        inp,
        split_sizes,
        sorted_indices,
        reordered,
        dst_rows,
        num_splits,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        reordered.stride(0),
        reordered.stride(1),
        triton.next_power_of_2(num_splits),
    )
    return reordered, dst_rows
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
    ],
    key=["hidden_size"],
)
@triton.jit
def _sort_chunks_by_map(
    # pointers
    input_ptr,  # [num_tokens, hidden_size] rows in the sorted layout
    output_ptr,  # [num_tokens, hidden_size] rows restored to the original layout
    row_id_map_ptr,  # [num_tokens] map from original row -> sorted row
    # sizes
    hidden_size,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    # Inverse of _sort_chunks_by_idxs_kernel: one program per output row;
    # gathers the row from the position the forward pass moved it to.
    pid = tl.program_id(0)  # output-row index
    dst_row = tl.load(row_id_map_ptr + pid)
    current_start = 0
    while current_start < hidden_size:
        current_offset = current_start + tl.arange(0, BLOCK_SIZE)
        mask = current_offset < hidden_size
        input_offsets = dst_row * stride_input_token + current_offset * stride_input_hidden
        output_offsets = pid * stride_output_token + current_offset * stride_output_hidden
        inp = tl.load(input_ptr + input_offsets, mask=mask)
        tl.store(output_ptr + output_offsets, inp, mask=mask)
        current_start += BLOCK_SIZE
def sort_chunks_by_map(
    inp: torch.Tensor,
    row_id_map: torch.Tensor,
    num_tokens: int,
    hidden_size: int,
):
    """Undo a chunk sort using the row map recorded by `sort_chunks_by_idx`."""
    restored = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
    _sort_chunks_by_map[(num_tokens,)](
        inp,
        restored,
        row_id_map,
        hidden_size,
        inp.stride(0),
        inp.stride(1),
        restored.stride(0),
        restored.stride(1),
    )
    return restored
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment