Unverified Commit eb9857d6 authored by hx's avatar hx Committed by GitHub
Browse files

[MoE][PyTorch] Add prob permutation to mask-based MoE permutation; Fix FP8 related codes (#1468)



* add prob permute; fix fp8tensor
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert unnecessary changes in UT
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>

* remove unnecessary probs dtype convert
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>

* keep the output nums if probs is not provided
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refine the doc string
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>

* fix lint
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>

* use fp32 compute type
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>

* style fix
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>

* fix empty input return
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>

* separate prob related functions out
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarHongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent b39397c5
......@@ -48,10 +48,14 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.moe_permute
.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index
.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs
.. autoapifunction:: transformer_engine.pytorch.initialize_ub
.. autoapifunction:: transformer_engine.pytorch.destroy_ub
This diff is collapsed.
......@@ -76,8 +76,10 @@ from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.permutation import (
moe_permute,
moe_permute_with_probs,
moe_unpermute,
moe_sort_chunks_by_index,
moe_sort_chunks_by_index_with_probs,
)
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.fp8 import fp8_model_init
......
......@@ -261,13 +261,17 @@ class _moe_permute_mask_map(torch.autograd.Function):
inp: torch.Tensor,
routing_map: torch.Tensor,
num_out_tokens: int,
probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# pylint: disable=missing-function-docstring
if not inp.numel():
return inp, torch.tensor([], device=inp.device)
ctx.probs = probs
return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device)
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert routing_map.is_cuda, "TransformerEngine needs CUDA."
if probs is not None:
assert probs.is_cuda, "TransformerEngine needs CUDA."
assert inp.size(0) == routing_map.size(0), "Permute not possible"
num_tokens, hidden_size = inp.size()
......@@ -282,48 +286,60 @@ class _moe_permute_mask_map(torch.autograd.Function):
if fp8:
fp8_dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
output = triton_permutation.permute_with_mask_map(
output, permuted_probs = triton_permutation.permute_with_mask_map(
inp,
row_id_map,
probs,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
)
if fp8:
output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv)
output = Float8Tensor(
data=output,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=output.shape,
dtype=fake_dtype,
)
ctx.save_for_backward(row_id_map)
ctx.num_experts = num_experts
ctx.num_tokens = num_tokens
ctx.hidden_size = hidden_size
return output, row_id_map
return output, row_id_map, permuted_probs
@staticmethod
def backward(
ctx,
permuted_act_grad: torch.Tensor,
_,
permuted_probs_grad: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
if not permuted_act_grad.numel():
return permuted_act_grad, None, None
return permuted_act_grad, None, None, ctx.probs
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
(row_id_map,) = ctx.saved_tensors
fp8 = isinstance(permuted_act_grad, Float8Tensor)
if fp8:
fp8_dtype = permuted_act_grad._fp8_dtype
fp8_scale_inv = permuted_act_grad._scale_inv
fake_dtype = permuted_act_grad.dtype
permuted_act_grad = permuted_act_grad._data
else:
fp8_dtype = None
act_grad = triton_permutation.unpermute_with_mask_map(
act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
permuted_act_grad,
row_id_map,
None,
permuted_probs_grad,
ctx.num_tokens,
ctx.num_experts,
ctx.hidden_size,
......@@ -334,8 +350,12 @@ class _moe_permute_mask_map(torch.autograd.Function):
data=act_grad,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv * ctx.num_experts,
shape=act_grad.shape,
dtype=fake_dtype,
)
return act_grad, None, None
if not ctx.needs_input_grad[3]:
probs_grad = None
return act_grad, None, None, probs_grad
class _moe_unpermute_mask_map(torch.autograd.Function):
......@@ -346,12 +366,12 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
ctx,
inp: torch.Tensor,
row_id_map: torch.Tensor,
probs: torch.Tensor,
merging_probs: torch.Tensor,
restore_shape: torch.Size,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if not inp.numel():
ctx.probs = probs
ctx.merging_probs = merging_probs
return inp
if restore_shape is None:
......@@ -359,15 +379,9 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
num_tokens, hidden_size = restore_shape
num_experts = row_id_map.size(0)
with_probs = probs is not None
with_probs = merging_probs is not None
if with_probs:
assert probs.is_cuda, "TransformerEngine needs CUDA."
if probs.dtype != torch.float32:
warnings.warn(
f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
"The recommended type is torch.float32."
)
probs = probs.to(torch.float32)
assert merging_probs.is_cuda, "TransformerEngine needs CUDA."
# Device check
assert inp.is_cuda, "TransformerEngine needs CUDA."
......@@ -380,13 +394,15 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
fp8_scale_inv = inp._scale_inv * num_experts
else:
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
fp8_dtype = None
unpermuted_output = triton_permutation.unpermute_with_mask_map(
unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
inp,
row_id_map,
probs,
merging_probs,
None,
num_tokens,
num_experts,
hidden_size,
......@@ -394,11 +410,15 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
)
if fp8:
unpermuted_output = Float8Tensor(
data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
data=unpermuted_output,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=unpermuted_output.shape,
dtype=fake_dtype,
)
if with_probs:
ctx.save_for_backward(inp, row_id_map, probs)
ctx.save_for_backward(inp, row_id_map, merging_probs)
else:
ctx.save_for_backward(row_id_map)
ctx.num_experts = num_experts
......@@ -412,13 +432,13 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
def backward(ctx, unpermuted_act_grad):
# pylint: disable=missing-function-docstring
if not unpermuted_act_grad.numel():
return unpermuted_act_grad, None, ctx.probs, None
return unpermuted_act_grad, None, ctx.merging_probs, None
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
if ctx.with_probs:
fwd_input, row_id_map, probs = ctx.saved_tensors
fwd_input, row_id_map, merging_probs = ctx.saved_tensors
else:
(row_id_map,) = ctx.saved_tensors
......@@ -426,26 +446,30 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
if fp8:
fp8_dtype = unpermuted_act_grad._fp8_dtype
fp8_scale_inv = unpermuted_act_grad._scale_inv
fake_dtype = unpermuted_act_grad.dtype
unpermuted_act_grad = unpermuted_act_grad._data
else:
fp8_dtype = None
if ctx.with_probs:
act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_probs(
act_grad, probs_grad = (
triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
unpermuted_act_grad,
row_id_map,
fwd_input,
probs,
merging_probs,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
ctx.hidden_size,
fp8_dtype,
)
)
else:
act_grad = triton_permutation.permute_with_mask_map(
act_grad, _ = triton_permutation.permute_with_mask_map(
unpermuted_act_grad,
row_id_map,
None,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
......@@ -454,7 +478,11 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
if fp8:
act_grad = Float8Tensor(
data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
data=act_grad,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=act_grad.shape,
dtype=fake_dtype,
)
if not ctx.needs_input_grad[2]:
......@@ -494,20 +522,56 @@ def moe_permute(
map_type: str, default = 'mask'
Type of the routing map tensor.
Options are: 'mask', 'index'.
Refer to `routing_map` for more details.
"""
if map_type == "index":
return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
if map_type == "mask":
return _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens)
output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None)
return output, row_id_map
raise ValueError("map_type should be one of 'mask' or 'index'")
def moe_permute_with_probs(
inp: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
num_out_tokens: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Permute the tokens and probs based on the routing_map.
Token with the same index will be grouped together.
Tokens with the same designated expert will be grouped together.
The routing_map indicates which experts were selected by each token.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
probs: torch.Tensor
The tensor of probabilities corresponding to the permuted tokens and is
of shape [num_tokens, num_experts]. It will be permuted with the tokens
according to the routing_map.
routing_map: torch.Tensor
The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
The values in it: 1 means the token is routed to this expert and 0 means not.
num_out_tokens: int, default = -1
The effective output token count, representing the number of tokens not dropped.
By default, set to '-1', meaning no tokens are dropped.
"""
output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
inp, routing_map, num_out_tokens, probs
)
return output, permuted_probs, row_id_map
def moe_unpermute(
inp: torch.Tensor,
row_id_map: torch.Tensor,
probs: torch.Tensor = None,
merging_probs: torch.Tensor = None,
restore_shape: torch.Tensor = None,
map_type: str = "mask",
probs: torch.Tensor = None,
) -> torch.Tensor:
"""
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
......@@ -520,7 +584,7 @@ def moe_unpermute(
row_id_map: torch.Tensor
The tensor of a mapping table for sorted indices used to unpermute the tokens,
which is the second output tensor of `Permute`.
probs: torch.Tensor
merging_probs: torch.Tensor, default = None
The tensor of probabilities corresponding to the permuted tokens. If provided,
the unpermuted tokens will be merged with their respective probabilities.
By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
......@@ -529,11 +593,20 @@ def moe_unpermute(
map_type: str, default = 'mask'
Type of the routing map tensor. Should be the same as the value passed to moe_permute.
Options are: 'mask', 'index'.
probs: torch.Tensor, default = None
Renamed to merging_probs. Keep for backward compatibility.
"""
if probs is not None:
if merging_probs is not None:
raise ValueError(
"Both merging_probs and probs kwarg are provided. probs is deprecated."
)
warnings.warn("probs kwarg is deprecated. Use merging_probs kwarg instead.")
merging_probs = probs
if map_type == "index":
return _moe_unpermute_index_map.apply(inp, row_id_map, probs)
return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs)
if map_type == "mask":
return _moe_unpermute_mask_map.apply(inp, row_id_map, probs, restore_shape)
return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape)
raise ValueError("map_type should be one of 'mask' or 'index'")
......@@ -546,14 +619,17 @@ class _moe_chunk_sort(torch.autograd.Function):
inp: torch.Tensor,
split_sizes: torch.Tensor,
sorted_idxs: torch.Tensor,
probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# pylint: disable=missing-function-docstring
if not inp.numel():
return inp, torch.tensor([], device=inp.device)
return inp, probs
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert split_sizes.is_cuda, "TransformerEngine needs CUDA."
assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA."
if probs is not None:
assert probs.is_cuda, "TransformerEngine needs CUDA."
num_tokens, hidden_size = inp.shape
num_splits = split_sizes.size(0)
......@@ -563,51 +639,69 @@ class _moe_chunk_sort(torch.autograd.Function):
if fp8:
fp8_dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
output, row_id_map = triton_permutation.sort_chunks_by_idx(
output, row_id_map, permuted_probs = triton_permutation.sort_chunks_by_idx(
inp,
split_sizes,
sorted_idxs,
probs,
num_tokens,
hidden_size,
num_splits,
)
if fp8:
output = Float8Tensor(data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv)
output = Float8Tensor(
data=output,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=output.shape,
dtype=fake_dtype,
)
ctx.save_for_backward(row_id_map)
ctx.num_tokens = num_tokens
ctx.hidden_size = hidden_size
return output
return output, permuted_probs
@staticmethod
def backward(
ctx,
permuted_act_grad: torch.Tensor,
permuted_probs_grad: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
if not permuted_act_grad.numel():
return permuted_act_grad, None, None
return permuted_act_grad, None, None, permuted_probs_grad
act_grad = None
probs_grad = None
if ctx.needs_input_grad[0]:
(row_id_map,) = ctx.saved_tensors
fp8 = isinstance(permuted_act_grad, Float8Tensor)
if fp8:
fp8_dtype = permuted_act_grad._fp8_dtype
fp8_scale_inv = permuted_act_grad._scale_inv
fake_dtype = permuted_act_grad.dtype
permuted_act_grad = permuted_act_grad._data
act_grad = triton_permutation.sort_chunks_by_map(
act_grad, probs_grad = triton_permutation.sort_chunks_by_map(
permuted_act_grad,
row_id_map,
permuted_probs_grad,
ctx.num_tokens,
ctx.hidden_size,
)
if fp8:
act_grad = Float8Tensor(
data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
data=act_grad,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=act_grad.shape,
dtype=fake_dtype,
)
return act_grad, None, None
if not ctx.needs_input_grad[3]:
probs_grad = None
return act_grad, None, None, probs_grad
def moe_sort_chunks_by_index(
......@@ -629,4 +723,33 @@ def moe_sort_chunks_by_index(
sorted_indices: torch.Tensor
Chunk indices used to permute the chunks.
"""
return _moe_chunk_sort.apply(inp, split_sizes, sorted_index)
output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None)
return output
def moe_sort_chunks_by_index_with_probs(
inp: torch.Tensor,
probs: torch.Tensor,
split_sizes: torch.Tensor,
sorted_index: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Split and sort the input tensor and probs based on the split_sizes and sorted indices.
The inp tensor is splitted along dim-0 according to the split_sizes list and then sorted
according to the sorted_indices.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
probs: torch.Tensor
The tensor of probabilities corresponding to the permuted tokens and is
of shape [num_tokens]. It will be permuted with the tokens according to
the split_sizes and sorted_indices.
split_sizes: torch.Tensor
Chunk sizes of the inp tensor along the 0-th dimension.
sorted_indices: torch.Tensor
Chunk indices used to permute the chunks.
"""
output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs)
return output, permuted_probs
......@@ -125,6 +125,8 @@ def _permute_kernel(
input_ptr,
output_ptr,
row_id_map_ptr,
probs_ptr,
permuted_probs_ptr,
# sizes
num_tokens,
num_experts,
......@@ -134,7 +136,11 @@ def _permute_kernel(
stride_input_hidden,
stride_output_token,
stride_output_hidden,
stride_probs_token,
stride_probs_expert,
stride_permuted_probs_token,
# metas
PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
......@@ -149,12 +155,19 @@ def _permute_kernel(
if dst_row != -1:
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
tl.store(output_ptr + output_off, inp, mask=mask)
if PERMUTE_PROBS:
if cur_pos == 0:
prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + prob_off)
permuted_prob_off = dst_row * stride_permuted_probs_token
tl.store(permuted_probs_ptr + permuted_prob_off, prob)
cur_pos += BLOCK_SIZE
def permute_with_mask_map(
inp: torch.Tensor,
row_id_map: torch.Tensor,
probs: torch.Tensor,
num_tokens: int,
num_experts: int,
num_out_tokens: int,
......@@ -162,11 +175,17 @@ def permute_with_mask_map(
):
# pylint: disable=missing-function-docstring
output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if probs is not None:
permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
else:
permuted_probs = None
grid = (num_tokens,)
_permute_kernel[grid](
inp,
output,
row_id_map,
probs,
permuted_probs,
num_tokens,
num_experts,
hidden_size,
......@@ -174,8 +193,12 @@ def permute_with_mask_map(
inp.stride(1),
output.stride(0),
output.stride(1),
probs.stride(0) if probs is not None else None,
probs.stride(1) if probs is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None,
PERMUTE_PROBS=probs is not None,
)
return output
return output, permuted_probs
@triton.autotune(
......@@ -194,7 +217,9 @@ def _unpermute_kernel(
input_ptr,
output_ptr,
row_id_map_ptr,
probs_ptr,
merging_probs_ptr,
permuted_probs_ptr,
unpermuted_probs_ptr,
# sizes
num_tokens,
num_experts,
......@@ -204,24 +229,27 @@ def _unpermute_kernel(
stride_input_hidden,
stride_output_token,
stride_output_hidden,
stride_probs_token,
stride_probs_expert,
stride_merging_probs_token,
stride_merging_probs_expert,
stride_permuted_probs_token,
stride_unpermuted_probs_token,
stride_unpermuted_probs_expert,
# metas
WITH_PROBS: tl.constexpr,
WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
FP8_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
if FP8_DTYPE == "e5m2":
compute_type = tl.float16
data_type = tl.float8e5
pytorch_tensor_dtype = tl.uint8
elif FP8_DTYPE == "e4m3":
compute_type = tl.float16
data_type = tl.float8e4nv
pytorch_tensor_dtype = tl.uint8
else:
compute_type = input_ptr.dtype.element_ty
data_type = input_ptr.dtype.element_ty
assert FP8_DTYPE is None
compute_type = tl.float32
pid = tl.program_id(0)
current_start = 0
......@@ -235,18 +263,35 @@ def _unpermute_kernel(
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
if FP8_DTYPE is not None:
inp = inp.to(data_type, bitcast=True).to(compute_type)
if WITH_PROBS:
prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + prob_off).to(compute_type)
inp *= prob
inp = inp.to(data_type, bitcast=True)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
merging_prob_off = (
pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
)
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
inp *= merging_prob
accumulator += inp
if PERMUTE_PROBS:
if current_start == 0:
unpermuted_prob_off = (
pid * stride_unpermuted_probs_token
+ expert_idx * stride_unpermuted_probs_expert
)
if src_row != -1:
permuted_prob_off = src_row * stride_permuted_probs_token
prob = tl.load(permuted_probs_ptr + permuted_prob_off)
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
else:
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0)
if FP8_DTYPE is not None:
if not WITH_PROBS:
if not WITH_MERGING_PROBS:
# Directly adding these value may cause overflow for fp8, we scale it here.
# The outside fp8_scale_inv is also scaled in the meantime.
accumulator /= num_experts
accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
else:
accumulator = accumulator.to(data_type)
output_off = pid * stride_output_token + current_offset * stride_output_hidden
tl.store(output_ptr + output_off, accumulator, mask=mask)
current_start += BLOCK_SIZE
......@@ -255,7 +300,8 @@ def _unpermute_kernel(
def unpermute_with_mask_map(
inp: torch.Tensor,
row_id_map: torch.Tensor,
probs: Union[torch.Tensor, None],
merging_probs: Union[torch.Tensor, None],
permuted_probs: Union[torch.Tensor, None],
num_tokens: int,
num_experts: int,
hidden_size: int,
......@@ -269,12 +315,20 @@ def unpermute_with_mask_map(
else:
fp8_dtype = None
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if permuted_probs is not None:
unpermuted_probs = torch.empty(
(num_tokens, num_experts), dtype=permuted_probs.dtype, device="cuda"
)
else:
unpermuted_probs = None
grid = (num_tokens,)
_unpermute_kernel[grid](
inp,
output,
row_id_map,
probs,
merging_probs,
permuted_probs,
unpermuted_probs,
num_tokens,
num_experts,
hidden_size,
......@@ -282,12 +336,16 @@ def unpermute_with_mask_map(
inp.stride(1),
output.stride(0),
output.stride(1),
probs.stride(0) if probs is not None else None,
probs.stride(1) if probs is not None else None,
WITH_PROBS=probs is not None,
merging_probs.stride(0) if merging_probs is not None else None,
merging_probs.stride(1) if merging_probs is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None,
unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None,
FP8_DTYPE=fp8_dtype,
)
return output
return output, unpermuted_probs
@triton.autotune(
......@@ -301,13 +359,13 @@ def unpermute_with_mask_map(
key=["hidden_size"],
)
@triton.jit
def _unpermute_bwd_with_probs_kernel(
def _unpermute_bwd_with_merging_probs_kernel(
# pointers
fwd_output_grad_ptr,
fwd_input_grad_ptr,
fwd_input_ptr,
probs_ptr,
probs_grad_ptr,
merging_probs_ptr,
merging_probs_grad_ptr,
row_id_map_ptr,
# sizes
num_tokens,
......@@ -320,31 +378,30 @@ def _unpermute_bwd_with_probs_kernel(
stride_fwd_input_grad_hidden,
stride_fwd_input_token,
stride_fwd_input_hidden,
stride_probs_token,
stride_probs_expert,
stride_probs_grad_token,
stride_probs_grad_expert,
stride_merging_probs_token,
stride_merging_probs_expert,
stride_merging_probs_grad_token,
stride_merging_probs_grad_expert,
# metas
FP8_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
if FP8_DTYPE == "e5m2":
compute_type = tl.float16
data_type = tl.float8e5
pytorch_tensor_dtype = tl.uint8
elif FP8_DTYPE == "e4m3":
compute_type = tl.float16
data_type = tl.float8e4nv
pytorch_tensor_dtype = tl.uint8
else:
compute_type = fwd_output_grad_ptr.dtype.element_ty
data_type = fwd_output_grad_ptr.dtype.element_ty
assert FP8_DTYPE is None
compute_type = tl.float32
pid = tl.program_id(0)
for expert_idx in range(num_experts):
dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
if dst_row != -1:
prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
current_start = 0
while current_start < hidden_size:
current_offset = current_start + tl.arange(0, BLOCK_SIZE)
......@@ -355,12 +412,16 @@ def _unpermute_bwd_with_probs_kernel(
)
inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
if FP8_DTYPE is not None:
inp = inp.to(data_type, bitcast=True).to(compute_type)
probs_off = pid * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + probs_off).to(compute_type)
output = inp * prob
inp = inp.to(data_type, bitcast=True)
inp = inp.to(compute_type)
merging_prob_off = (
pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
)
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
output = inp * merging_prob
output = output.to(data_type)
if FP8_DTYPE is not None:
output = output.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
output = output.to(pytorch_tensor_dtype, bitcast=True)
output_off = (
dst_row * stride_fwd_input_grad_token
+ current_offset * stride_fwd_input_grad_hidden
......@@ -373,21 +434,27 @@ def _unpermute_bwd_with_probs_kernel(
fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
if FP8_DTYPE is not None:
fwd_input = fwd_input.to(data_type, bitcast=True)
prob_grad_accum += fwd_input.to(tl.float32) * inp.to(tl.float32)
prob_grad_accum += fwd_input.to(compute_type) * inp
current_start += BLOCK_SIZE
probs_grad = tl.sum(prob_grad_accum)
probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert
tl.store(probs_grad_ptr + probs_grad_off, probs_grad)
probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
probs_grad_off = (
pid * stride_merging_probs_grad_token
+ expert_idx * stride_merging_probs_grad_expert
)
tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
else:
probs_grad_off = pid * stride_probs_grad_token + expert_idx * stride_probs_grad_expert
tl.store(probs_grad_ptr + probs_grad_off, 0.0)
probs_grad_off = (
pid * stride_merging_probs_grad_token
+ expert_idx * stride_merging_probs_grad_expert
)
tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0)
def unpermute_with_mask_map_bwd_with_probs(
def unpermute_with_mask_map_bwd_with_merging_probs(
fwd_output_grad: torch.Tensor,
row_id_map: torch.Tensor,
fwd_input: torch.Tensor,
probs: torch.Tensor,
merging_probs: torch.Tensor,
num_tokens: int,
num_experts: int,
num_out_tokens: int,
......@@ -404,14 +471,16 @@ def unpermute_with_mask_map_bwd_with_probs(
act_grad = torch.empty(
(num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
)
probs_grad = torch.empty((num_tokens, num_experts), dtype=probs.dtype, device="cuda")
merging_probs_grad = torch.empty(
(num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda"
)
grid = (num_tokens,)
_unpermute_bwd_with_probs_kernel[grid](
_unpermute_bwd_with_merging_probs_kernel[grid](
fwd_output_grad,
act_grad,
fwd_input,
probs,
probs_grad,
merging_probs,
merging_probs_grad,
row_id_map,
num_tokens,
num_experts,
......@@ -422,13 +491,13 @@ def unpermute_with_mask_map_bwd_with_probs(
act_grad.stride(1),
fwd_input.stride(0),
fwd_input.stride(1),
probs.stride(0),
probs.stride(1),
probs_grad.stride(0),
probs_grad.stride(1),
merging_probs.stride(0),
merging_probs.stride(1),
merging_probs_grad.stride(0),
merging_probs_grad.stride(1),
fp8_dtype,
)
return act_grad, probs_grad
return act_grad, merging_probs_grad
@triton.autotune(
......@@ -449,6 +518,8 @@ def _sort_chunks_by_idxs_kernel(
sorted_indices_ptr,
output_ptr,
dst_rows_ptr,
probs_ptr,
permuted_probs_ptr,
# sizes
num_splits,
hidden_size,
......@@ -457,7 +528,10 @@ def _sort_chunks_by_idxs_kernel(
stride_input_hidden,
stride_output_token,
stride_output_hidden,
stride_probs_token,
stride_permuted_probs_token,
# metas
PERMUTE_PROBS: tl.constexpr,
IDX_LOAD_WIDTH: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
......@@ -508,11 +582,18 @@ def _sort_chunks_by_idxs_kernel(
tl.store(output_ptr + output_offsets, inp, mask=mask)
current_start += BLOCK_SIZE
if PERMUTE_PROBS:
prob_off = pid * stride_probs_token
prob = tl.load(probs_ptr + prob_off)
permuted_prob_off = dst_row * stride_permuted_probs_token
tl.store(permuted_probs_ptr + permuted_prob_off, prob)
def sort_chunks_by_idx(
inp: torch.Tensor,
split_sizes: torch.Tensor,
sorted_indices: torch.Tensor,
probs: torch.Tensor,
num_tokens: int,
hidden_size: int,
num_splits: int,
......@@ -520,6 +601,10 @@ def sort_chunks_by_idx(
# pylint: disable=missing-function-docstring
row_id_map = torch.empty((num_tokens,), dtype=torch.int64, device="cuda")
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if probs is not None:
permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda")
else:
permuted_probs = None
grid = (num_tokens,)
_sort_chunks_by_idxs_kernel[grid](
inp,
......@@ -527,15 +612,20 @@ def sort_chunks_by_idx(
sorted_indices,
output,
row_id_map,
probs,
permuted_probs,
num_splits,
hidden_size,
inp.stride(0),
inp.stride(1),
output.stride(0),
output.stride(1),
triton.next_power_of_2(num_splits),
probs.stride(0) if probs is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None,
PERMUTE_PROBS=probs is not None,
IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits),
)
return output, row_id_map
return output, row_id_map, permuted_probs
@triton.autotune(
......@@ -554,6 +644,8 @@ def _sort_chunks_by_map(
input_ptr,
output_ptr,
row_id_map_ptr,
probs_ptr,
permuted_probs_ptr,
# sizes
hidden_size,
# strides
......@@ -561,7 +653,10 @@ def _sort_chunks_by_map(
stride_input_hidden,
stride_output_token,
stride_output_hidden,
stride_probs_token,
stride_permuted_probs_token,
# metas
PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
......@@ -575,25 +670,40 @@ def _sort_chunks_by_map(
inp = tl.load(input_ptr + input_offsets, mask=mask)
tl.store(output_ptr + output_offsets, inp, mask=mask)
current_start += BLOCK_SIZE
if PERMUTE_PROBS:
prob_off = dst_row * stride_probs_token
prob = tl.load(probs_ptr + prob_off)
permuted_prob_off = pid * stride_permuted_probs_token
tl.store(permuted_probs_ptr + permuted_prob_off, prob)
def sort_chunks_by_map(
inp: torch.Tensor,
row_id_map: torch.Tensor,
probs: torch.Tensor,
num_tokens: int,
hidden_size: int,
):
# pylint: disable=missing-function-docstring
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if probs is not None:
permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda")
else:
permuted_probs = None
grid = (num_tokens,)
_sort_chunks_by_map[grid](
inp,
output,
row_id_map,
probs,
permuted_probs,
hidden_size,
inp.stride(0),
inp.stride(1),
output.stride(0),
output.stride(1),
probs.stride(0) if probs is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None,
PERMUTE_PROBS=probs is not None,
)
return output
return output, permuted_probs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment