Unverified Commit c8e7cc02 authored by Autumn1998's avatar Autumn1998 Committed by GitHub
Browse files

[MoE] Support new fp8 recipes for permute_fusion (#1649)



* add support for new recipe on permute_fusion, rm fp unpermute
Signed-off-by: tongliu <tongliu@nvidia.com>

* fix lint
Signed-off-by: Xin Yao <xiny@nvidia.com>

* remove fp8 from index map
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* skip unsupported tests
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: tongliu <tongliu@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: tongliu <tongliu@nvidia.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d9eb0582
This diff is collapsed.
......@@ -333,7 +333,7 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
const transformer_engine::Tensor *input_fwd_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input_fwd);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T,
nvte_permute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
reinterpret_cast<T *>(output_cu->data.dptr),
......@@ -359,7 +359,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T,
nvte_unpermute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
reinterpret_cast<T *>(output_cu->data.dptr),
......
......@@ -52,18 +52,11 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr,
num_tokens * topK);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc
num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
at::Tensor permuted_output = torch::empty(
{num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
at::Tensor permuted_output =
torch::empty({num_out_tokens, num_cols},
torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
at::Tensor row_id_map = torch::empty(
{num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
......@@ -100,17 +93,10 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
using namespace transformer_engine::pytorch;
int num_cols = input.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc
at::Tensor unpermuted_output = torch::empty(
{num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
at::Tensor unpermuted_output =
torch::empty({num_tokens, num_cols},
torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
auto stream = at::cuda::getCurrentCUDAStream().stream();
......@@ -136,17 +122,10 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
int num_cols = input_bwd.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input_bwd.scalar_type();
// Output buffer alloc
at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols},
torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
at::Tensor act_grad =
torch::empty({input_fwd.size(0), num_cols},
torch::dtype(input_bwd.scalar_type()).device(torch::kCUDA).requires_grad(false));
at::Tensor prob_grad = torch::empty(
{num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
......
......@@ -4,14 +4,16 @@
"""MoE Permutaion API"""
import warnings
from typing import Tuple
from typing import Optional, Tuple
import torch
import transformer_engine_torch as tex
import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
__all__ = [
"moe_permute",
......@@ -46,17 +48,7 @@ class _moe_permute_index_map(torch.autograd.Function):
assert inp.size(0) == index.size(0), "Permute not possible"
# Data type check
fp8 = isinstance(inp, Float8Tensor)
if fp8:
assert (
inp._quantizer.scale.ndim == 0
), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute."
dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
dtype = TE_DType[inp.dtype]
dtype = TE_DType[inp.dtype]
if index.dtype != torch.int32:
warnings.warn(
f"The data type of the input `index` of Permute is {index.dtype}! "
......@@ -80,19 +72,9 @@ class _moe_permute_index_map(torch.autograd.Function):
_moe_permute_index_map.max_expanded_token_num,
)
if fp8:
permuted_act = Float8Tensor(
data=permuted_act,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=permuted_act.shape,
dtype=fake_dtype,
)
ctx.row_id_map = row_id_map
ctx.num_tokens = index.size(0)
ctx.topK = index.size(1)
ctx.fp8 = fp8
return permuted_act, row_id_map
@staticmethod
......@@ -109,30 +91,12 @@ class _moe_permute_index_map(torch.autograd.Function):
if not permuted_act_grad.is_contiguous():
permuted_act_grad = permuted_act_grad.contiguous()
if ctx.fp8:
assert isinstance(
permuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
dtype = permuted_act_grad._fp8_dtype
fp8_scale_inv = permuted_act_grad._scale_inv
fake_dtype = permuted_act_grad.dtype
permuted_act_grad = permuted_act_grad._data
else:
dtype = TE_DType[permuted_act_grad.dtype]
dtype = TE_DType[permuted_act_grad.dtype]
act_grad = None
if ctx.needs_input_grad[0]:
act_grad = tex.moe_permute_bwd(
permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
)
if ctx.fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv * ctx.topK,
shape=act_grad.shape,
dtype=fake_dtype,
)
return act_grad, None, None, None
......@@ -176,14 +140,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
# Data type check
fp8 = isinstance(inp, Float8Tensor)
if fp8:
dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
dtype = TE_DType[inp.dtype]
dtype = TE_DType[inp.dtype]
if row_id_map.dtype != torch.int32:
warnings.warn(
f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
......@@ -193,17 +150,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)
if fp8:
unpermuted_output = Float8Tensor(
data=unpermuted_output,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=unpermuted_output.shape,
dtype=fake_dtype,
)
ctx.save_for_backward(inp, row_id_map, probs)
ctx.fp8 = fp8
return unpermuted_output
@staticmethod
......@@ -219,17 +166,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
if not unpermuted_act_grad.is_contiguous():
unpermuted_act_grad = unpermuted_act_grad.contiguous()
if ctx.fp8:
assert isinstance(
unpermuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
dtype = unpermuted_act_grad._fp8_dtype
fp8_scale_inv = unpermuted_act_grad._scale_inv
fake_dtype = unpermuted_act_grad.dtype
unpermuted_act_grad = unpermuted_act_grad._data
else:
dtype = TE_DType[unpermuted_act_grad.dtype]
dtype = TE_DType[unpermuted_act_grad.dtype]
inp, row_id_map, probs = ctx.saved_tensors
act_grad = None
......@@ -238,14 +175,6 @@ class _moe_unpermute_index_map(torch.autograd.Function):
act_grad, prob_grad = tex.moe_unpermute_bwd(
unpermuted_act_grad, inp, dtype, row_id_map, probs
)
if ctx.fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=act_grad.shape,
dtype=fake_dtype,
)
if not ctx.needs_input_grad[2]:
prob_grad = None
......@@ -282,29 +211,86 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)
fp8 = isinstance(inp, Float8Tensor)
fp8 = isinstance(inp, QuantizedTensor)
per_tensor_recipe = isinstance(inp, Float8Tensor)
blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor)
mxfp8_recipe = isinstance(inp, MXFP8Tensor)
if fp8:
fp8_dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
output, permuted_probs = triton_permutation.permute_with_mask_map(
# blockwise scaling
if blockwise_recipe:
fp8_scale = inp._rowwise_scale_inv.T.contiguous()
scale_hidden_dim = fp8_scale.shape[1]
assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
inp = inp._rowwise_data
# mxfp8 scaling
elif mxfp8_recipe:
fp8_scale = inp._rowwise_scale_inv.contiguous()
scale_hidden_dim = fp8_scale.shape[1]
assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
inp = inp._rowwise_data
# per-tensor scaling
elif per_tensor_recipe:
# Kernel does not need scale in per-tensor scaling
fp8_scale = None
scale_hidden_dim = None
fp8_scale_inv = inp._scale_inv
inp = inp._data
else:
raise ValueError("Unsupported FP8 recipe")
else:
fp8_scale = None
fp8_dtype = None
scale_hidden_dim = None
output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map(
inp,
row_id_map,
probs,
fp8_scale,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
scale_hidden_dim,
)
if fp8:
output = Float8Tensor(
data=output,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=output.shape,
dtype=fake_dtype,
)
if per_tensor_recipe:
output = Float8Tensor(
data=output,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=output.shape,
dtype=fake_dtype,
)
elif blockwise_recipe:
output = Float8BlockwiseQTensor(
shape=output.shape,
dtype=fake_dtype,
rowwise_data=output,
rowwise_scale_inv=permuted_scale.T.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
quantizer=None,
is_2D_scaled=False,
requires_grad=output.requires_grad,
)
elif mxfp8_recipe:
output = MXFP8Tensor(
shape=output.shape,
dtype=fake_dtype,
fp8_dtype=fp8_dtype,
rowwise_data=output,
rowwise_scale_inv=permuted_scale.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
quantizer=None,
requires_grad=output.requires_grad,
)
ctx.save_for_backward(row_id_map)
ctx.num_experts = num_experts
......@@ -327,14 +313,9 @@ class _moe_permute_mask_map(torch.autograd.Function):
probs_grad = None
if ctx.needs_input_grad[0]:
(row_id_map,) = ctx.saved_tensors
fp8 = isinstance(permuted_act_grad, Float8Tensor)
if fp8:
fp8_dtype = permuted_act_grad._fp8_dtype
fp8_scale_inv = permuted_act_grad._scale_inv
fake_dtype = permuted_act_grad.dtype
permuted_act_grad = permuted_act_grad._data
else:
fp8_dtype = None
assert not isinstance(
permuted_act_grad, QuantizedTensor
), "The backward of moe_permute does not support FP8."
act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
permuted_act_grad,
row_id_map,
......@@ -343,16 +324,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
ctx.num_tokens,
ctx.num_experts,
ctx.hidden_size,
fp8_dtype,
)
if fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv * ctx.num_experts,
shape=act_grad.shape,
dtype=fake_dtype,
)
if not ctx.needs_input_grad[3]:
probs_grad = None
return act_grad, None, None, probs_grad
......@@ -366,8 +338,8 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
ctx,
inp: torch.Tensor,
row_id_map: torch.Tensor,
merging_probs: torch.Tensor,
restore_shape: torch.Size,
merging_probs: Optional[torch.Tensor],
restore_shape: Optional[torch.Size],
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if not inp.numel():
......@@ -387,17 +359,9 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
fp8 = isinstance(inp, Float8Tensor)
if fp8:
fp8_dtype = inp._fp8_dtype
if not with_probs:
fp8_scale_inv = inp._scale_inv * num_experts
else:
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
fp8_dtype = None
assert not isinstance(
inp, QuantizedTensor
), "The forward of moe_unpermute does not support FP8."
unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
inp,
row_id_map,
......@@ -406,16 +370,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
num_tokens,
num_experts,
hidden_size,
fp8_dtype=fp8_dtype,
)
if fp8:
unpermuted_output = Float8Tensor(
data=unpermuted_output,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=unpermuted_output.shape,
dtype=fake_dtype,
)
if with_probs:
ctx.save_for_backward(inp, row_id_map, merging_probs)
......@@ -442,16 +397,44 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
else:
(row_id_map,) = ctx.saved_tensors
fp8 = isinstance(unpermuted_act_grad, Float8Tensor)
fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor)
mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor)
if fp8:
fp8_dtype = unpermuted_act_grad._fp8_dtype
fp8_scale_inv = unpermuted_act_grad._scale_inv
fake_dtype = unpermuted_act_grad.dtype
unpermuted_act_grad = unpermuted_act_grad._data
# per-tensor scaling
if per_tensor_recipe:
# Kernel does not need scale in per-tensor scaling
fp8_scale = None
scale_hidden_dim = None
fp8_scale_inv = unpermuted_act_grad._scale_inv
unpermuted_act_grad = unpermuted_act_grad._data
# blockwise scaling
elif blockwise_recipe:
fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
unpermuted_act_grad = unpermuted_act_grad._rowwise_data
scale_hidden_dim = fp8_scale.shape[1]
assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
# mxfp8 scaling
elif mxfp8_recipe:
fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous()
unpermuted_act_grad = unpermuted_act_grad._rowwise_data
scale_hidden_dim = fp8_scale.shape[1]
assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
else:
raise ValueError("Unsupported FP8 recipe")
else:
scale_hidden_dim = None
fp8_dtype = None
fp8_scale = None
if ctx.with_probs:
assert (
not fp8
), "The backward of moe_unpermute with merging probs does not support FP8."
act_grad, probs_grad = (
triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
unpermuted_act_grad,
......@@ -462,28 +445,55 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
ctx.num_experts,
ctx.num_permuted_tokens,
ctx.hidden_size,
fp8_dtype,
)
)
else:
act_grad, _ = triton_permutation.permute_with_mask_map(
act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map(
unpermuted_act_grad,
row_id_map,
None,
fp8_scale,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
ctx.hidden_size,
scale_hidden_dim,
)
if fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=act_grad.shape,
dtype=fake_dtype,
)
if per_tensor_recipe:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=act_grad.shape,
dtype=fake_dtype,
)
elif blockwise_recipe:
act_grad = Float8BlockwiseQTensor(
shape=act_grad.shape,
dtype=fake_dtype,
rowwise_data=act_grad,
rowwise_scale_inv=permuted_scale.T.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
quantizer=None,
is_2D_scaled=False,
requires_grad=act_grad.requires_grad,
)
elif mxfp8_recipe:
act_grad = MXFP8Tensor(
shape=act_grad.shape,
dtype=fake_dtype,
fp8_dtype=fp8_dtype,
rowwise_data=act_grad,
rowwise_scale_inv=permuted_scale.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
quantizer=None,
requires_grad=act_grad.requires_grad,
)
if not ctx.needs_input_grad[2]:
probs_grad = None
......@@ -568,10 +578,10 @@ def moe_permute_with_probs(
def moe_unpermute(
inp: torch.Tensor,
row_id_map: torch.Tensor,
merging_probs: torch.Tensor = None,
restore_shape: torch.Tensor = None,
merging_probs: Optional[torch.Tensor] = None,
restore_shape: Optional[torch.Size] = None,
map_type: str = "mask",
probs: torch.Tensor = None,
probs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
......@@ -588,7 +598,7 @@ def moe_unpermute(
The tensor of probabilities corresponding to the permuted tokens. If provided,
the unpermuted tokens will be merged with their respective probabilities.
By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
restore_shape: torch.Tensor
restore_shape: torch.Size, default = None
The output shape after the unpermute operation.
map_type: str, default = 'mask'
Type of the routing map tensor. Should be the same as the value passed to moe_permute.
......
......@@ -10,8 +10,6 @@ import torch
import triton
import triton.language as tl
from transformer_engine_torch import DType as TE_DType
@triton.jit
def _row_id_map_pass_1_kernel(
......@@ -116,11 +114,14 @@ def _permute_kernel(
output_ptr,
row_id_map_ptr,
probs_ptr,
scale_ptr,
permuted_probs_ptr,
permuted_scale_ptr,
# sizes
num_tokens,
num_experts,
hidden_size,
scale_hidden_dim,
# strides
stride_input_token,
stride_input_hidden,
......@@ -128,9 +129,14 @@ def _permute_kernel(
stride_output_hidden,
stride_probs_token,
stride_probs_expert,
stride_scale_token,
stride_scale_hidden,
stride_permuted_probs_token,
stride_permuted_scale_token,
stride_permuted_scale_hidden,
# metas
PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
......@@ -140,11 +146,21 @@ def _permute_kernel(
mask = cur_off < hidden_size
input_off = pid * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
if PERMUTE_SCALE:
mask_scale = cur_off < scale_hidden_dim
scale_off = pid * stride_scale_token + cur_off * stride_scale_hidden
scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
for expert_idx in range(num_experts):
dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
if dst_row != -1:
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
tl.store(output_ptr + output_off, inp, mask=mask)
if PERMUTE_SCALE:
permuted_scale_off = (
dst_row * stride_permuted_scale_token
+ cur_off * stride_permuted_scale_hidden
)
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS:
if cur_pos == 0:
prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
......@@ -173,10 +189,12 @@ def permute_with_mask_map(
inp: torch.Tensor,
row_id_map: torch.Tensor,
probs: torch.Tensor,
scale: torch.Tensor,
num_tokens: int,
num_experts: int,
num_out_tokens: int,
hidden_size: int,
scale_hidden_dim: int,
):
# pylint: disable=missing-function-docstring
output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
......@@ -184,26 +202,42 @@ def permute_with_mask_map(
permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
else:
permuted_probs = None
if scale is not None:
permuted_scale = torch.empty(
(num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda"
)
else:
permuted_scale = None
grid = (num_tokens,)
_permute_kernel[grid](
inp,
output,
row_id_map,
probs,
scale,
permuted_probs,
permuted_scale,
num_tokens,
num_experts,
hidden_size,
scale_hidden_dim,
inp.stride(0),
inp.stride(1),
output.stride(0),
output.stride(1),
probs.stride(0) if probs is not None else None,
probs.stride(1) if probs is not None else None,
scale.stride(0) if scale is not None else None,
scale.stride(1) if scale is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None,
permuted_scale.stride(0) if permuted_scale is not None else None,
permuted_scale.stride(1) if permuted_scale is not None else None,
PERMUTE_PROBS=probs is not None,
PERMUTE_SCALE=scale is not None,
)
return output, permuted_probs
return output, permuted_scale, permuted_probs
@triton.jit
......@@ -232,18 +266,9 @@ def _unpermute_kernel(
# metas
WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
FP8_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
if FP8_DTYPE == "e5m2":
data_type = tl.float8e5
pytorch_tensor_dtype = tl.uint8
elif FP8_DTYPE == "e4m3":
data_type = tl.float8e4nv
pytorch_tensor_dtype = tl.uint8
else:
data_type = input_ptr.dtype.element_ty
assert FP8_DTYPE is None
data_type = input_ptr.dtype.element_ty
compute_type = tl.float32
pid = tl.program_id(0)
......@@ -257,8 +282,6 @@ def _unpermute_kernel(
if src_row != -1:
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
if FP8_DTYPE is not None:
inp = inp.to(data_type, bitcast=True)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
merging_prob_off = (
......@@ -279,14 +302,7 @@ def _unpermute_kernel(
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
else:
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0)
if FP8_DTYPE is not None:
if not WITH_MERGING_PROBS:
# Directly adding these value may cause overflow for fp8, we scale it here.
# The outside fp8_scale_inv is also scaled in the meantime.
accumulator /= num_experts
accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
else:
accumulator = accumulator.to(data_type)
accumulator = accumulator.to(data_type)
output_off = pid * stride_output_token + current_offset * stride_output_hidden
tl.store(output_ptr + output_off, accumulator, mask=mask)
current_start += BLOCK_SIZE
......@@ -315,15 +331,8 @@ def unpermute_with_mask_map(
num_tokens: int,
num_experts: int,
hidden_size: int,
fp8_dtype: TE_DType,
):
# pylint: disable=missing-function-docstring
if fp8_dtype == TE_DType.kFloat8E5M2:
fp8_dtype = "e5m2"
elif fp8_dtype == TE_DType.kFloat8E4M3:
fp8_dtype = "e4m3"
else:
fp8_dtype = None
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if permuted_probs is not None:
unpermuted_probs = torch.empty(
......@@ -353,7 +362,6 @@ def unpermute_with_mask_map(
unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None,
FP8_DTYPE=fp8_dtype,
)
return output, unpermuted_probs
......@@ -383,18 +391,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_grad_token,
stride_merging_probs_grad_expert,
# metas
FP8_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
if FP8_DTYPE == "e5m2":
data_type = tl.float8e5
pytorch_tensor_dtype = tl.uint8
elif FP8_DTYPE == "e4m3":
data_type = tl.float8e4nv
pytorch_tensor_dtype = tl.uint8
else:
data_type = fwd_output_grad_ptr.dtype.element_ty
assert FP8_DTYPE is None
data_type = fwd_output_grad_ptr.dtype.element_ty
compute_type = tl.float32
pid = tl.program_id(0)
......@@ -411,8 +410,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
+ current_offset * stride_fwd_output_grad_hidden
)
inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
if FP8_DTYPE is not None:
inp = inp.to(data_type, bitcast=True)
inp = inp.to(compute_type)
merging_prob_off = (
pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
......@@ -420,8 +417,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
output = inp * merging_prob
output = output.to(data_type)
if FP8_DTYPE is not None:
output = output.to(pytorch_tensor_dtype, bitcast=True)
output_off = (
dst_row * stride_fwd_input_grad_token
+ current_offset * stride_fwd_input_grad_hidden
......@@ -432,8 +427,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
)
fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
if FP8_DTYPE is not None:
fwd_input = fwd_input.to(data_type, bitcast=True)
prob_grad_accum += fwd_input.to(compute_type) * inp
current_start += BLOCK_SIZE
probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
......@@ -474,15 +467,8 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
num_experts: int,
num_out_tokens: int,
hidden_size: int,
fp8_dtype: TE_DType,
):
# pylint: disable=missing-function-docstring
if fp8_dtype == TE_DType.kFloat8E5M2:
fp8_dtype = "e5m2"
elif fp8_dtype == TE_DType.kFloat8E4M3:
fp8_dtype = "e4m3"
else:
fp8_dtype = None
act_grad = torch.empty(
(num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
)
......@@ -510,7 +496,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs.stride(1),
merging_probs_grad.stride(0),
merging_probs_grad.stride(1),
fp8_dtype,
)
return act_grad, merging_probs_grad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment