Commit dc97cc9e (unverified), authored by hx, committed by GitHub
Browse files

[PyTorch] Optimize the performance of permute fusion kernels (#1927)



* optimize permute
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent 37da2d3b
......@@ -326,6 +326,7 @@ def _test_permutation_index_map(
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
if not BENCHMARK:
torch.testing.assert_close(
pytorch_permute_output.float(),
te_permute_output_,
......@@ -351,7 +352,10 @@ def _test_permutation_index_map(
)
if with_probs:
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols
probs.grad.float(),
te_probs.grad.float(),
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if not pytorch_permute_fwd_input.numel():
......@@ -538,6 +542,7 @@ def _test_permutation_mask_map(
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
if not BENCHMARK:
torch.testing.assert_close(
pytorch_permute_output.float(),
te_permute_output_,
......@@ -564,7 +569,10 @@ def _test_permutation_mask_map(
)
if with_probs:
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols
probs.grad.float(),
te_probs.grad.float(),
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if not pytorch_permute_fwd_input.numel():
......@@ -827,6 +835,7 @@ def _test_moe_chunk_sort(
te_output_ = te_output.float()
te_fwd_input_grad = te_fwd_input.grad.float()
if not BENCHMARK:
torch.testing.assert_close(
pytorch_output.float(),
te_output_,
......@@ -887,6 +896,7 @@ def _test_permutation_mask_map_alongside_probs(
topK,
num_out_tokens,
tp_size,
BENCHMARK=False,
):
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
......@@ -1016,6 +1026,7 @@ def _test_permutation_mask_map_alongside_probs(
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
if not BENCHMARK:
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
......@@ -1032,6 +1043,57 @@ def _test_permutation_mask_map_alongside_probs(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols
)
if BENCHMARK:
t1 = perf_test_cuda_kernel(
lambda: te_permute_with_probs(
te_permute_fwd_input, te_probs, routing_map, num_out_tokens=num_out_tokens
)
)
print(f"permute\t\tfwd: TE: {t1:.3f} ms")
te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
te_permute_fwd_input,
te_probs,
routing_map,
num_out_tokens=num_out_tokens,
)
te_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_permute_output,
te_permute_bwd_input,
forward_input=[te_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"permute\t\tbwd: TE: {t2:.3f} ms")
chunk_sort_fwd_input = te_permute_output.detach()
chunk_sort_fwd_input.requires_grad_(True)
chunk_sort_fwd_probs = te_permuted_probs.detach()
chunk_sort_fwd_probs.requires_grad_(True)
t1 = perf_test_cuda_kernel(
lambda: te_sort_chunks_by_index_with_probs(
chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
)
)
print(f"chunk sort\t\tfwd: TE: {t1:.3f} ms")
chunk_sort_output, _ = te_sort_chunks_by_index_with_probs(
chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
chunk_sort_output,
te_permute_bwd_input,
forward_input=[chunk_sort_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"chunk sort\t\tbwd: TE: {t2:.3f} ms")
def perf_test_cuda_kernel(cuda_kernel_fn):
if torch.cuda.is_available():
......@@ -1063,7 +1125,7 @@ if is_bf16_compatible():
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
......@@ -1092,7 +1154,7 @@ def test_permutation_index_map(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
......@@ -1138,7 +1200,7 @@ def test_permutation_mask_map_empty_input(te_dtype):
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
......@@ -1193,7 +1255,7 @@ fp8_recipes = [
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
......@@ -1225,7 +1287,7 @@ def test_permutation_mask_map_fp8(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_index_map_topk1_no_probs(
te_dtype,
......@@ -1252,7 +1314,7 @@ def test_permutation_index_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_mask_map_topk1_no_probs(
te_dtype,
......@@ -1279,7 +1341,7 @@ def test_permutation_mask_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permutation(
......@@ -1372,5 +1434,108 @@ def test_permutation_single_case():
)
def benchmark_single_case(
    te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
):
    """Run every permutation benchmark once for a single problem configuration.

    Each sub-benchmark is one of the ``_test_*`` helpers of this file invoked with
    ``BENCHMARK=True`` and wrapped in its own NVTX range; all of them are nested
    inside an outer range labelled with the problem sizes so the runs are easy to
    locate in a profiler timeline.

    Parameters
    ----------
    te_dtype
        Data type forwarded unchanged to every helper.
    num_tokens, num_expert, hidden_size, topK, num_out_tokens
        Problem sizes forwarded unchanged to every helper.
    ep_size
        Only used in the outer NVTX range label; it is not passed to any helper.
    tp_size
        Used in the outer NVTX range label and forwarded to
        ``_test_permutation_mask_map_alongside_probs``.
    """
    # Keyword arguments common to all four helper calls.
    shared_kwargs = {
        "te_dtype": te_dtype,
        "num_tokens": num_tokens,
        "num_expert": num_expert,
        "hidden_size": hidden_size,
        "topK": topK,
        "num_out_tokens": num_out_tokens,
        "BENCHMARK": True,
    }
    # (NVTX label, helper, helper-specific keyword arguments), in execution order.
    cases = (
        ("permutation_index_map_with_probs", _test_permutation_index_map, {"with_probs": True}),
        ("permutation_mask_map_with_probs", _test_permutation_mask_map, {"with_probs": True}),
        ("permutation_mask_map_without_probs", _test_permutation_mask_map, {"with_probs": False}),
        (
            "permutation_mask_map_alongside_probs",
            _test_permutation_mask_map_alongside_probs,
            {"tp_size": tp_size},
        ),
    )

    torch.cuda.nvtx.range_push(
        f"{num_tokens}-{num_expert}-{hidden_size}-{topK}-{ep_size}-{tp_size}"
    )
    try:
        for label, run_case, extra_kwargs in cases:
            torch.cuda.nvtx.range_push(label)
            try:
                run_case(**shared_kwargs, **extra_kwargs)
            finally:
                # Always close the range, even if the helper raises, so the
                # push/pop nesting stays balanced for whatever runs afterwards.
                torch.cuda.nvtx.range_pop()
    finally:
        torch.cuda.nvtx.range_pop()
def benchmark_multiple_cases():
    """Print the GPU name, then benchmark a fixed list of problem configurations.

    Every configuration is run through ``benchmark_single_case`` in BF16 with
    ``num_out_tokens`` set to ``num_tokens * topK``.
    """
    print("GPU:", torch.cuda.get_device_name(0))

    # Alternative dtypes that can be swapped in here:
    # tex.DType.kFloat32, tex.DType.kFloat16
    te_dtype = tex.DType.kBFloat16

    # One row per benchmark run:
    # (ep_size, tp_size, num_tokens, num_expert, hidden_size, topK)
    configurations = (
        (64, 2, 4096, 256, 7168, 8),
        (8, 1, 8192 * 2, 128, 4096, 6),
        (64, 2, 16384, 4, 7168, 1),
    )

    for ep_size, tp_size, num_tokens, num_expert, hidden_size, topK in configurations:
        benchmark_single_case(
            te_dtype,
            num_tokens,
            num_expert,
            hidden_size,
            topK,
            num_tokens * topK,
            ep_size,
            tp_size,
        )
# Script entry point: when the file is executed directly (rather than collected by
# pytest), run the single-case test first and then the benchmark configurations.
if __name__ == "__main__":
test_permutation_single_case()
benchmark_multiple_cases()
......@@ -349,7 +349,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
if restore_shape is None:
restore_shape = inp.shape
num_tokens, hidden_size = restore_shape
num_experts = row_id_map.size(0)
num_experts = (row_id_map.size(1) - 1) // 2
with_probs = merging_probs is not None
if with_probs:
......@@ -651,14 +651,20 @@ class _moe_chunk_sort(torch.autograd.Function):
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
output, row_id_map, permuted_probs = triton_permutation.sort_chunks_by_idx(
inp,
row_id_map = triton_permutation.make_chunk_sort_map(
split_sizes,
sorted_idxs,
num_tokens,
num_splits,
)
output, permuted_probs = triton_permutation.sort_chunks_by_map(
inp,
row_id_map,
probs,
num_tokens,
hidden_size,
num_splits,
is_forward=True,
)
if fp8:
output = Float8Tensor(
......@@ -700,6 +706,7 @@ class _moe_chunk_sort(torch.autograd.Function):
permuted_probs_grad,
ctx.num_tokens,
ctx.hidden_size,
is_forward=False,
)
if fp8:
act_grad = Float8Tensor(
......
......@@ -10,6 +10,72 @@ import torch
import triton
import triton.language as tl
from triton.language import core
from triton.language.standard import _log2
# The following three argsort related kernels are adapted from
# the issue https://github.com/triton-lang/triton/issues/3698
# One compare-and-swap step of the sorting network: compares paired elements of `x`
# and conditionally exchanges them, applying the same exchange to `indices`.
# Returns the updated (values, indices) pair.
#   x       : values being sorted
#   indices : payload that is permuted together with x
#   flip    : per-element flag XOR-ed with the (left > right) comparison to decide on a swap
#   i       : step selector; the two partners of a pair are 2 ** (n_dims - i - 1) elements apart
#   n_dims  : number of low bits of the flat index that belong to one sorted group
#             (x.numel >> n_dims is the number of such groups)
@triton.jit
def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
n_outer: tl.constexpr = x.numel >> n_dims
# View the data as [groups, 2, stride] so that axis 1 holds the two partners of each pair.
shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)]
y = tl.reshape(x, shape)
z = tl.reshape(indices, shape)
# mask is 0 for the first ("left") partner and 1 for the second ("right") partner.
mask = tl.arange(0, 2)[None, :, None]
# Multiplying by (1 - mask) or mask and summing over axis 1 extracts one partner;
# broadcasting back gives every element both the left and the right value of its pair.
l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to(
x.dtype
)
r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to(
x.dtype
)
l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape)
r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape)
# Reinterpret the values as signed integers of the same bit width so the exchange
# can be expressed with XOR.
idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
il_value = l_value.to(idtype, bitcast=True)
ir_value = r_value.to(idtype, bitcast=True)
ix = x.to(idtype, bitcast=True)
# Where a swap is required, XOR with (left ^ right) turns each element into its partner;
# elsewhere XOR with 0 leaves it unchanged.
flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix))
ret = ix ^ flag1
# Same XOR exchange for the indices.
# NOTE(review): the zero operand takes ix's dtype rather than indices' dtype — confirm this
# is intended when the two differ.
flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix))
ind = indices ^ flag2
return ret.to(x.dtype, bitcast=True), ind
# Runs `stage` consecutive compare-and-swap steps on (x, indices) with a direction
# pattern chosen by `order`, and returns the updated (values, indices) pair.
#   stage  : number of compare-and-swap steps to run (must not exceed n_dims)
#   order  : 0, 1 or 2 — see the string below for their meaning
#   n_dims : forwarded to _compare_and_swap
@triton.jit
def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
n_outer: tl.constexpr = x.numel >> n_dims
tl.static_assert(stage <= n_dims)
"""
order_type 0 == ascending
order_type 1 == descending
order_type 2 == alternating
"""
if order == 2:
# Alternating: build a 0/1 pattern that switches every 2**stage elements, so
# neighbouring runs of that length get opposite flip values.
shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage]
flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
else:
# Uniform direction: every element uses the same flip value (`order` itself).
flip = tl.full(x.shape, value=order, dtype=tl.int32)
# Steps use i = n_dims - stage, ..., n_dims - 1, i.e. partner distances
# 2**(stage - 1) down to 1 inside _compare_and_swap.
for i in tl.static_range(stage):
x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims)
return x, indices
# Sorts `x` and applies the same reordering to `indices`; returns (sorted values, indices).
# Merge stages 1 .. n_dims - 1 use the alternating order (2); the final stage uses order 1,
# which _bitonic_merge labels as descending.
# NOTE(review): presumably the sorted length must equal 2 ** n_dims — the caller visible in
# this file passes n_dims = _log2(LOAD_SIZE) for a block of LOAD_SIZE elements; confirm for
# any other caller.
@triton.jit
def _argsort(x, indices, n_dims: tl.constexpr):
for i in tl.static_range(1, n_dims + 1):
x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims)
return x, indices
@triton.jit
def _row_id_map_pass_1_kernel(
......@@ -22,6 +88,8 @@ def _row_id_map_pass_1_kernel(
# strides
stride_routing_map_token,
stride_routing_map_expert,
stride_row_id_map_token,
stride_row_id_map_expert,
# metas
BLOCK_SIZE: tl.constexpr,
):
......@@ -32,10 +100,10 @@ def _row_id_map_pass_1_kernel(
routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
mask=(offset < num_tokens),
other=0,
).to(tl.int64)
).to(tl.int32)
row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
tl.store(
row_id_map_ptr + pid_m * num_tokens + offset,
row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
row_id_within_token_block,
mask=offset < num_tokens,
)
......@@ -50,6 +118,9 @@ def _row_id_map_pass_2_kernel(
workspace_ptr,
# sizes
num_tokens,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
# metas
WORKSPACE_LOAD_WIDTH: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
......@@ -59,7 +130,9 @@ def _row_id_map_pass_2_kernel(
chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
row_id_within_token_block = tl.load(
row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0
row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
mask=(offset < num_tokens),
other=0,
)
workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
......@@ -70,23 +143,102 @@ def _row_id_map_pass_2_kernel(
row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
)
tl.store(
row_id_map_ptr + pid_m * num_tokens + offset,
row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
row_id,
mask=(offset < num_tokens),
)
# Pass 3 of building row_id_map: rewrites one row of the map in place.
# Each program handles the row selected by program_id(0) (offset pid * stride_row_id_map_token).
# After the kernel, for that row:
#   * slots [0, n_routed)                         hold the sorted row-id values,
#   * slots [num_experts, num_experts + n_routed) hold the expert index belonging to each of them,
#   * slot  num_experts * 2                       holds n_routed.
# Slots outside those ranges are not written by this kernel.
@triton.jit
def _row_id_map_pass_3_kernel(
# pointers
row_id_map_ptr,
# sizes
num_experts: tl.constexpr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
# metas
LOAD_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
# The launcher in make_row_id_map passes LOAD_SIZE = triton.next_power_of_2(num_experts).
n_dims: tl.constexpr = _log2(LOAD_SIZE)
off = tl.arange(0, LOAD_SIZE)
# Load the first num_experts entries of the row; lanes beyond num_experts are filled
# with -1, the same value that is treated as "not routed" in the count below.
row_id_map = tl.load(
row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off,
mask=off < num_experts,
other=-1,
)
# Number of entries that are not -1.
n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0))
# `indices` starts as the expert position of every entry and is permuted with the values.
indices = off
# The last stage of _argsort uses the descending order, so -1 entries end up at the back
# (assumes valid row ids are >= 0 — TODO confirm against pass 2).
sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims)
# Write only the first n_routed sorted values back to the start of the row.
tl.store(
row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert,
sorted_map,
mask=off < n_routed,
)
# Write the matching expert indices starting at column num_experts.
tl.store(
row_id_map_ptr
+ pid * stride_row_id_map_token
+ (num_experts + off) * stride_row_id_map_expert,
indices,
mask=off < n_routed,
)
# Write the routed count into column num_experts * 2.
tl.store(
row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert,
n_routed,
)
def make_row_id_map(
routing_map: torch.Tensor,
num_tokens: int,
num_experts: int,
):
# pylint: disable=missing-function-docstring
row_id_map = torch.empty((num_experts, num_tokens), dtype=torch.int64, device="cuda")
block_size = 256
"""
Prepare the row_id_map for the permutation.
Parameters
----------
routing_map: torch.Tensor
Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
which experts are routed to which tokens. The values in it: 1 means the token is routed to
this expert and 0 means not.
num_tokens: int
Number of tokens in the input tensor.
num_experts: int
Number of experts in the input tensor.
Returns
-------
row_id_map: torch.Tensor
The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
For each token, the last item is the number of experts that are routed (n_routed).
The first n_routed items are the destination row indices in the permuted tokens.
The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding
to the first n_routed row indices above.
"""
row_id_map = torch.empty((num_tokens, num_experts * 2 + 1), dtype=torch.int32, device="cuda")
block_size = 1024
grid = (num_experts, triton.cdiv(num_tokens, block_size))
workspace_tensor = torch.empty(grid, dtype=torch.int64, device="cuda")
# block cumsum
workspace_tensor = torch.empty(grid, dtype=torch.int32, device="cuda")
# supposing num_tokens == 5, num_experts == 3, block_size == 3
# and we have a routing_map like this:
# [[1, 1, 0],
# [1, 0, 1],
# [0, 0, 1],
# [1, 1, 0],
# [0, 0, 0]]
# pass 1: block cumsum
# for each expert, compute the cumsum of every block_size tokens
# the row_id_map will be like this after pass 1 (r means useless values):
# [[1, 1, 0, r, r, r, r],
# [2, 0, 1, r, r, r, r],
# [0, 0, 2, r, r, r, r],
# [1, 1, 0, r, r, r, r],
# [0, 0, 0, r, r, r, r]]
_row_id_map_pass_1_kernel[grid](
routing_map,
row_id_map,
......@@ -94,16 +246,44 @@ def make_row_id_map(
num_tokens,
routing_map.stride(0),
routing_map.stride(1),
row_id_map.stride(0),
row_id_map.stride(1),
block_size,
)
# cumsum all and process the mask
# pass 2: cumsum all and process the mask
# process the block cumsum into the global cumsum and then into the dst row indices
# the row_id_map will be like this after pass 2 (r means useless value):
# [[ 0, 3, -1, r, r, r, r],
# [ 1, -1, 5, r, r, r, r],
# [-1, -1, 6, r, r, r, r],
# [ 2, 4, -1, r, r, r, r],
# [-1, -1, -1, r, r, r, r]]
_row_id_map_pass_2_kernel[grid](
row_id_map,
workspace_tensor,
num_tokens,
row_id_map.stride(0),
row_id_map.stride(1),
triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)),
block_size,
)
# pass 3: make the row_id_map from the sparse structure to the dense structure
# the row_id_map will be like this after pass 3 (r means useless value):
# [[3, 0, r, 1, 0, r, 2],
# [5, 1, r, 2, 0, r, 2],
# [6, r, r, 2, r, r, 1],
# [4, 2, r, 1, 0, r, 2],
# [r, r, r, r, r, r, 0]]
grid = (num_tokens,)
_row_id_map_pass_3_kernel[grid](
row_id_map,
num_experts,
row_id_map.stride(0),
row_id_map.stride(1),
triton.next_power_of_2(num_experts),
)
return row_id_map
......@@ -118,11 +298,12 @@ def _permute_kernel(
permuted_probs_ptr,
permuted_scale_ptr,
# sizes
num_tokens,
num_experts,
hidden_size,
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
scale_hidden_dim,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
stride_input_token,
stride_input_hidden,
stride_output_token,
......@@ -139,35 +320,50 @@ def _permute_kernel(
PERMUTE_SCALE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
cur_pos = 0
while cur_pos < hidden_size:
cur_off = cur_pos + tl.arange(0, BLOCK_SIZE)
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = cur_off < hidden_size
input_off = pid * stride_input_token + cur_off * stride_input_hidden
input_off = pid_t * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
if PERMUTE_SCALE:
mask_scale = cur_off < scale_hidden_dim
scale_off = pid * stride_scale_token + cur_off * stride_scale_hidden
scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden
scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
for expert_idx in range(num_experts):
dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
if dst_row != -1:
n_routed = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ num_experts * 2 * stride_row_id_map_expert
)
for idx in tl.range(n_routed):
dst_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
)
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
tl.store(output_ptr + output_off, inp, mask=mask)
if PERMUTE_SCALE:
permuted_scale_off = (
dst_row * stride_permuted_scale_token
+ cur_off * stride_permuted_scale_hidden
dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
)
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS:
if cur_pos == 0:
prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + prob_off)
if pid_h == 0:
permuted_prob_off = dst_row * stride_permuted_probs_token
tl.store(permuted_probs_ptr + permuted_prob_off, prob)
cur_pos += BLOCK_SIZE
if prob == 0.0:
# for routing_map padding
# dst_row != -1 and prob == 0.0 means that this slot is padded
tl.store(output_ptr + output_off, 0, mask=mask)
else:
tl.store(output_ptr + output_off, inp, mask=mask)
else:
tl.store(output_ptr + output_off, inp, mask=mask)
try:
......@@ -178,6 +374,8 @@ try:
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
],
key=["hidden_size"],
)(_permute_kernel)
......@@ -196,7 +394,30 @@ def permute_with_mask_map(
hidden_size: int,
scale_hidden_dim: int,
):
# pylint: disable=missing-function-docstring
"""
Permute the input tensor based on the row_id_map.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
row_id_map: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
probs: torch.Tensor
The probabilities of the input tensor. If it is not None, it will be permuted.
scale: torch.Tensor
The scale of the input tensor. If it is not None, it will be permuted.
num_tokens: int
Number of tokens in the input tensor.
num_experts: int
Number of experts in the input tensor.
num_out_tokens: int
Number of tokens in the permuted tensor.
hidden_size: int
Hidden size of the input tensor.
scale_hidden_dim: int
Hidden size of the scale tensor.
"""
output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if probs is not None:
permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
......@@ -209,8 +430,8 @@ def permute_with_mask_map(
)
else:
permuted_scale = None
grid = (num_tokens,)
# pylint: disable=unnecessary-lambda-assignment
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_permute_kernel[grid](
inp,
output,
......@@ -219,10 +440,11 @@ def permute_with_mask_map(
scale,
permuted_probs,
permuted_scale,
num_tokens,
num_experts,
hidden_size,
scale_hidden_dim,
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0),
inp.stride(1),
output.stride(0),
......@@ -250,10 +472,11 @@ def _unpermute_kernel(
permuted_probs_ptr,
unpermuted_probs_ptr,
# sizes
num_tokens,
num_experts,
hidden_size,
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
stride_input_token,
stride_input_hidden,
stride_output_token,
......@@ -264,6 +487,7 @@ def _unpermute_kernel(
stride_unpermuted_probs_token,
stride_unpermuted_probs_expert,
# metas
PROBS_LOAD_WIDTH: tl.constexpr,
WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
......@@ -271,41 +495,63 @@ def _unpermute_kernel(
data_type = input_ptr.dtype.element_ty
compute_type = tl.float32
pid = tl.program_id(0)
current_start = 0
while current_start < hidden_size:
current_offset = current_start + tl.arange(0, BLOCK_SIZE)
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
if PERMUTE_PROBS:
# write 0.0 to probs_grad that are not routed
if pid_h == 0:
map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
unpermuted_prob_off = (
pid_t * stride_unpermuted_probs_token
+ stride_unpermuted_probs_expert * map_load_off
)
tl.store(
unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts
)
accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
for expert_idx in range(num_experts):
src_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
if src_row != -1:
n_routed = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ num_experts * 2 * stride_row_id_map_expert
)
for idx in tl.range(n_routed):
src_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
)
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
merging_prob_off = (
pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
)
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
inp *= merging_prob
accumulator += inp
if PERMUTE_PROBS:
if current_start == 0:
if pid_h == 0:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
unpermuted_prob_off = (
pid * stride_unpermuted_probs_token
pid_t * stride_unpermuted_probs_token
+ expert_idx * stride_unpermuted_probs_expert
)
if src_row != -1:
permuted_prob_off = src_row * stride_permuted_probs_token
prob = tl.load(permuted_probs_ptr + permuted_prob_off)
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
else:
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0)
accumulator = accumulator.to(data_type)
output_off = pid * stride_output_token + current_offset * stride_output_hidden
output_off = pid_t * stride_output_token + current_offset * stride_output_hidden
tl.store(output_ptr + output_off, accumulator, mask=mask)
current_start += BLOCK_SIZE
try:
......@@ -316,6 +562,8 @@ try:
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
],
key=["hidden_size"],
)(_unpermute_kernel)
......@@ -332,7 +580,27 @@ def unpermute_with_mask_map(
num_experts: int,
hidden_size: int,
):
# pylint: disable=missing-function-docstring
"""
Unpermute the input tensor based on the row_id_map.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_out_tokens, hidden_size]`.
row_id_map: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
merging_probs: torch.Tensor
The merging probabilities of the input tensor. If it is not None, it will be used as weights
to reduce the unpermuted tokens.
permuted_probs: torch.Tensor
The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
num_tokens: int
Number of tokens in the permuted tensor.
num_experts: int
Number of experts in the permuted tensor.
hidden_size: int
Hidden size of the permuted tensor.
"""
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if permuted_probs is not None:
unpermuted_probs = torch.empty(
......@@ -340,7 +608,8 @@ def unpermute_with_mask_map(
)
else:
unpermuted_probs = None
grid = (num_tokens,)
# pylint: disable=unnecessary-lambda-assignment
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_unpermute_kernel[grid](
inp,
output,
......@@ -348,9 +617,10 @@ def unpermute_with_mask_map(
merging_probs,
permuted_probs,
unpermuted_probs,
num_tokens,
num_experts,
hidden_size,
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0),
inp.stride(1),
output.stride(0),
......@@ -360,6 +630,7 @@ def unpermute_with_mask_map(
permuted_probs.stride(0) if permuted_probs is not None else None,
unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None,
)
......@@ -376,10 +647,11 @@ def _unpermute_bwd_with_merging_probs_kernel(
merging_probs_grad_ptr,
row_id_map_ptr,
# sizes
num_tokens,
num_experts,
hidden_size,
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
stride_fwd_output_grad_token,
stride_fwd_output_grad_hidden,
stride_fwd_input_grad_token,
......@@ -391,23 +663,37 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_grad_token,
stride_merging_probs_grad_expert,
# metas
PROBS_LOAD_WIDTH: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
data_type = fwd_output_grad_ptr.dtype.element_ty
compute_type = tl.float32
pid = tl.program_id(0)
for expert_idx in range(num_experts):
dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
if dst_row != -1:
map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
token_probs_grad_off = (
pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off
)
tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts)
n_routed = tl.load(
row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
)
for idx in tl.range(n_routed):
dst_row = tl.load(
row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
)
expert_idx = tl.load(
row_id_map_ptr
+ pid * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
current_start = 0
while current_start < hidden_size:
current_offset = current_start + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
input_off = (
pid * stride_fwd_output_grad_token
+ current_offset * stride_fwd_output_grad_hidden
pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden
)
inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
......@@ -431,16 +717,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
current_start += BLOCK_SIZE
probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
probs_grad_off = (
pid * stride_merging_probs_grad_token
+ expert_idx * stride_merging_probs_grad_expert
pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
)
tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
else:
probs_grad_off = (
pid * stride_merging_probs_grad_token
+ expert_idx * stride_merging_probs_grad_expert
)
tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0)
try:
......@@ -451,6 +730,8 @@ try:
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
],
key=["hidden_size"],
)(_unpermute_bwd_with_merging_probs_kernel)
......@@ -468,7 +749,28 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
num_out_tokens: int,
hidden_size: int,
):
# pylint: disable=missing-function-docstring
"""
Unpermute backward pass kernel with merging probs.
Parameters
----------
fwd_output_grad: torch.Tensor
The gradient of the output tensor of shape `[num_tokens, hidden_size]`.
row_id_map: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
fwd_input: torch.Tensor
The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
merging_probs: torch.Tensor
The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
num_tokens: int
Number of tokens in the permuted tensor.
num_experts: int
Number of experts in the permuted tensor.
num_out_tokens: int
Number of tokens in the output tensor.
hidden_size: int
Hidden size of the output tensor.
"""
act_grad = torch.empty(
(num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
)
......@@ -483,9 +785,10 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs,
merging_probs_grad,
row_id_map,
num_tokens,
num_experts,
hidden_size,
row_id_map.stride(0),
row_id_map.stride(1),
fwd_output_grad.stride(0),
fwd_output_grad.stride(1),
act_grad.stride(0),
......@@ -496,34 +799,21 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs.stride(1),
merging_probs_grad.stride(0),
merging_probs_grad.stride(1),
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
)
return act_grad, merging_probs_grad
@triton.jit
def _sort_chunks_by_idxs_kernel(
def _make_chunk_sort_map_kernel(
# pointers
input_ptr,
split_sizes_ptr,
sorted_indices_ptr,
output_ptr,
dst_rows_ptr,
probs_ptr,
permuted_probs_ptr,
# sizes
num_splits,
hidden_size,
# strides
stride_input_token,
stride_input_hidden,
stride_output_token,
stride_output_hidden,
stride_probs_token,
stride_permuted_probs_token,
num_splits: tl.constexpr,
# metas
PERMUTE_PROBS: tl.constexpr,
IDX_LOAD_WIDTH: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
......@@ -533,104 +823,58 @@ def _sort_chunks_by_idxs_kernel(
)
# get chunk idx of the current token in the input tensor
input_chunk_idx = -1
in_chunk_offset = tl.zeros([], dtype=tl.int64)
acc_chunk_sizes = tl.zeros([], dtype=tl.int64)
cursor = 0
while cursor < num_splits:
cur_chunk_size = tl.load(split_sizes_ptr + cursor).to(tl.int64)
acc_chunk_sizes += cur_chunk_size
if input_chunk_idx == -1 and acc_chunk_sizes > pid:
input_chunk_idx = cursor
in_chunk_offset = pid - (acc_chunk_sizes - cur_chunk_size)
cursor += 1
input_split_sizes = tl.load(
split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
).to(tl.int32)
input_split_sizes_cumsum = tl.cumsum(input_split_sizes)
input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
input_chunk_idx = tl.sum(input_split_sizes_mask)
input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
in_chunk_offset = pid - input_split_sizes_presum
# get chunk idx of the current token in the output tensor
output_chunk_idx = 0
cursor = 0
while cursor < num_splits:
cur_input_idx = tl.load(sorted_indices_ptr + cursor)
if cur_input_idx == input_chunk_idx:
output_chunk_idx = cursor
cursor += 1
output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0)
output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1)
# make row_id_map
output_split_sizes = tl.load(
split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits
).to(tl.int64)
).to(tl.int32)
output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
tl.store(dst_rows_ptr + pid, dst_row)
current_start = 0
while current_start < hidden_size:
current_offset = current_start + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
input_offsets = pid * stride_input_token + current_offset * stride_input_hidden
output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
inp = tl.load(input_ptr + input_offsets, mask=mask)
tl.store(output_ptr + output_offsets, inp, mask=mask)
current_start += BLOCK_SIZE
if PERMUTE_PROBS:
prob_off = pid * stride_probs_token
prob = tl.load(probs_ptr + prob_off)
permuted_prob_off = dst_row * stride_permuted_probs_token
tl.store(permuted_probs_ptr + permuted_prob_off, prob)
try:
_sort_chunks_by_idxs_kernel = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
],
key=["hidden_size"],
)(_sort_chunks_by_idxs_kernel)
except RuntimeError:
pass
def sort_chunks_by_idx(
inp: torch.Tensor,
def make_chunk_sort_map(
split_sizes: torch.Tensor,
sorted_indices: torch.Tensor,
probs: torch.Tensor,
num_tokens: int,
hidden_size: int,
num_splits: int,
):
# pylint: disable=missing-function-docstring
row_id_map = torch.empty((num_tokens,), dtype=torch.int64, device="cuda")
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if probs is not None:
permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda")
else:
permuted_probs = None
"""
Make a row_id_map for chunk sort.
Parameters
----------
split_sizes: torch.Tensor
The sizes of the chunks of shape `[num_splits,]`.
sorted_indices: torch.Tensor
The indices of the sorted chunks of shape `[num_splits,]`.
num_tokens: int
Number of tokens in the input tensor.
num_splits: int
Number of splits of split_sizes and sorted_indices.
"""
row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device="cuda")
grid = (num_tokens,)
_sort_chunks_by_idxs_kernel[grid](
inp,
_make_chunk_sort_map_kernel[grid](
split_sizes,
sorted_indices,
output,
row_id_map,
probs,
permuted_probs,
num_splits,
hidden_size,
inp.stride(0),
inp.stride(1),
output.stride(0),
output.stride(1),
probs.stride(0) if probs is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None,
PERMUTE_PROBS=probs is not None,
IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits),
)
return output, row_id_map, permuted_probs
return row_id_map
@triton.jit
......@@ -642,7 +886,7 @@ def _sort_chunks_by_map_kernel(
probs_ptr,
permuted_probs_ptr,
# sizes
hidden_size,
hidden_size: tl.constexpr,
# strides
stride_input_token,
stride_input_hidden,
......@@ -653,22 +897,27 @@ def _sort_chunks_by_map_kernel(
# metas
PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
FORWARD: tl.constexpr,
):
pid = tl.program_id(0)
dst_row = tl.load(row_id_map_ptr + pid)
current_start = 0
while current_start < hidden_size:
current_offset = current_start + tl.arange(0, BLOCK_SIZE)
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
if FORWARD:
src_row = pid_t
dst_row = tl.load(row_id_map_ptr + pid_t)
else:
src_row = tl.load(row_id_map_ptr + pid_t)
dst_row = pid_t
current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
input_offsets = dst_row * stride_input_token + current_offset * stride_input_hidden
output_offsets = pid * stride_output_token + current_offset * stride_output_hidden
input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
inp = tl.load(input_ptr + input_offsets, mask=mask)
tl.store(output_ptr + output_offsets, inp, mask=mask)
current_start += BLOCK_SIZE
if PERMUTE_PROBS:
prob_off = dst_row * stride_probs_token
if pid_h == 0:
prob_off = src_row * stride_probs_token
prob = tl.load(probs_ptr + prob_off)
permuted_prob_off = pid * stride_permuted_probs_token
permuted_prob_off = dst_row * stride_permuted_probs_token
tl.store(permuted_probs_ptr + permuted_prob_off, prob)
......@@ -680,6 +929,8 @@ try:
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
],
key=["hidden_size"],
)(_sort_chunks_by_map_kernel)
......@@ -693,14 +944,33 @@ def sort_chunks_by_map(
probs: torch.Tensor,
num_tokens: int,
hidden_size: int,
is_forward: bool,
):
# pylint: disable=missing-function-docstring
"""
Sort chunks with row_id_map.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`.
row_id_map: torch.Tensor
The row mapping tensor of shape `[num_tokens,]`, such as the one returned by
`make_chunk_sort_map`; entry `i` holds the sorted-row index paired with row `i`.
probs: torch.Tensor
The probabilities of the input tensor. If it is not None, it will be permuted.
num_tokens: int
Number of tokens in the input tensor.
hidden_size: int
Hidden size of the input tensor.
is_forward: bool
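If True (forward), input row `i` is written to output row `row_id_map[i]`;
if False (backward), output row `i` is read from input row `row_id_map[i]`.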
"""
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if probs is not None:
permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda")
else:
permuted_probs = None
grid = (num_tokens,)
# pylint: disable=unnecessary-lambda-assignment
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_sort_chunks_by_map_kernel[grid](
inp,
output,
......@@ -715,5 +985,6 @@ def sort_chunks_by_map(
probs.stride(0) if probs is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None,
PERMUTE_PROBS=probs is not None,
FORWARD=is_forward,
)
return output, permuted_probs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment