Unverified Commit c8e7cc02 authored by Autumn1998, committed by GitHub
Browse files

[MoE] Support new fp8 recipes for permute_fusion (#1649)



* add support for new recipe on permute_fusion, remove FP8 unpermute
Signed-off-by: tongliu <tongliu@nvidia.com>

* fix lint
Signed-off-by: Xin Yao <xiny@nvidia.com>

* remove fp8 from index map
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* skip unsupported tests
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: tongliu <tongliu@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: tongliu <tongliu@nvidia.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d9eb0582
This diff is collapsed.
...@@ -333,7 +333,7 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so ...@@ -333,7 +333,7 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
const transformer_engine::Tensor *input_fwd_cu = const transformer_engine::Tensor *input_fwd_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input_fwd); reinterpret_cast<const transformer_engine::Tensor *>(input_fwd);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T, input_cu->data.dtype, T,
nvte_permute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr), nvte_permute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
reinterpret_cast<T *>(output_cu->data.dptr), reinterpret_cast<T *>(output_cu->data.dptr),
...@@ -359,7 +359,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id ...@@ -359,7 +359,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
const transformer_engine::Tensor *prob_cu = const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob); reinterpret_cast<const transformer_engine::Tensor *>(prob);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T, input_cu->data.dtype, T,
nvte_unpermute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr), nvte_unpermute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
reinterpret_cast<T *>(output_cu->data.dptr), reinterpret_cast<T *>(output_cu->data.dptr),
......
...@@ -52,18 +52,11 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( ...@@ -52,18 +52,11 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr, sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr,
num_tokens * topK); num_tokens * topK);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc // Output buffer alloc
num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK; num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
at::Tensor permuted_output = torch::empty( at::Tensor permuted_output =
{num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); torch::empty({num_out_tokens, num_cols},
torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
at::Tensor row_id_map = torch::empty( at::Tensor row_id_map = torch::empty(
{num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false)); {num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
...@@ -100,17 +93,10 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d ...@@ -100,17 +93,10 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
using namespace transformer_engine::pytorch; using namespace transformer_engine::pytorch;
int num_cols = input.size(1); int num_cols = input.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc // Output buffer alloc
at::Tensor unpermuted_output = torch::empty( at::Tensor unpermuted_output =
{num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); torch::empty({num_tokens, num_cols},
torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
...@@ -136,17 +122,10 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T ...@@ -136,17 +122,10 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
int num_cols = input_bwd.size(1); int num_cols = input_bwd.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input_bwd.scalar_type();
// Output buffer alloc // Output buffer alloc
at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols}, at::Tensor act_grad =
torch::dtype(_st).device(torch::kCUDA).requires_grad(false)); torch::empty({input_fwd.size(0), num_cols},
torch::dtype(input_bwd.scalar_type()).device(torch::kCUDA).requires_grad(false));
at::Tensor prob_grad = torch::empty( at::Tensor prob_grad = torch::empty(
{num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false)); {num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
......
...@@ -4,14 +4,16 @@ ...@@ -4,14 +4,16 @@
"""MoE Permutaion API""" """MoE Permutaion API"""
import warnings import warnings
from typing import Tuple from typing import Optional, Tuple
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
import transformer_engine.pytorch.triton.permutation as triton_permutation import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
__all__ = [ __all__ = [
"moe_permute", "moe_permute",
...@@ -46,17 +48,7 @@ class _moe_permute_index_map(torch.autograd.Function): ...@@ -46,17 +48,7 @@ class _moe_permute_index_map(torch.autograd.Function):
assert inp.size(0) == index.size(0), "Permute not possible" assert inp.size(0) == index.size(0), "Permute not possible"
# Data type check # Data type check
fp8 = isinstance(inp, Float8Tensor) dtype = TE_DType[inp.dtype]
if fp8:
assert (
inp._quantizer.scale.ndim == 0
), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute."
dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
dtype = TE_DType[inp.dtype]
if index.dtype != torch.int32: if index.dtype != torch.int32:
warnings.warn( warnings.warn(
f"The data type of the input `index` of Permute is {index.dtype}! " f"The data type of the input `index` of Permute is {index.dtype}! "
...@@ -80,19 +72,9 @@ class _moe_permute_index_map(torch.autograd.Function): ...@@ -80,19 +72,9 @@ class _moe_permute_index_map(torch.autograd.Function):
_moe_permute_index_map.max_expanded_token_num, _moe_permute_index_map.max_expanded_token_num,
) )
if fp8:
permuted_act = Float8Tensor(
data=permuted_act,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=permuted_act.shape,
dtype=fake_dtype,
)
ctx.row_id_map = row_id_map ctx.row_id_map = row_id_map
ctx.num_tokens = index.size(0) ctx.num_tokens = index.size(0)
ctx.topK = index.size(1) ctx.topK = index.size(1)
ctx.fp8 = fp8
return permuted_act, row_id_map return permuted_act, row_id_map
@staticmethod @staticmethod
...@@ -109,30 +91,12 @@ class _moe_permute_index_map(torch.autograd.Function): ...@@ -109,30 +91,12 @@ class _moe_permute_index_map(torch.autograd.Function):
if not permuted_act_grad.is_contiguous(): if not permuted_act_grad.is_contiguous():
permuted_act_grad = permuted_act_grad.contiguous() permuted_act_grad = permuted_act_grad.contiguous()
if ctx.fp8: dtype = TE_DType[permuted_act_grad.dtype]
assert isinstance(
permuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
dtype = permuted_act_grad._fp8_dtype
fp8_scale_inv = permuted_act_grad._scale_inv
fake_dtype = permuted_act_grad.dtype
permuted_act_grad = permuted_act_grad._data
else:
dtype = TE_DType[permuted_act_grad.dtype]
act_grad = None act_grad = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
act_grad = tex.moe_permute_bwd( act_grad = tex.moe_permute_bwd(
permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
) )
if ctx.fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv * ctx.topK,
shape=act_grad.shape,
dtype=fake_dtype,
)
return act_grad, None, None, None return act_grad, None, None, None
...@@ -176,14 +140,7 @@ class _moe_unpermute_index_map(torch.autograd.Function): ...@@ -176,14 +140,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
assert row_id_map.is_cuda, "TransformerEngine needs CUDA." assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
# Data type check # Data type check
fp8 = isinstance(inp, Float8Tensor) dtype = TE_DType[inp.dtype]
if fp8:
dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
dtype = TE_DType[inp.dtype]
if row_id_map.dtype != torch.int32: if row_id_map.dtype != torch.int32:
warnings.warn( warnings.warn(
f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
...@@ -193,17 +150,7 @@ class _moe_unpermute_index_map(torch.autograd.Function): ...@@ -193,17 +150,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)
if fp8:
unpermuted_output = Float8Tensor(
data=unpermuted_output,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=unpermuted_output.shape,
dtype=fake_dtype,
)
ctx.save_for_backward(inp, row_id_map, probs) ctx.save_for_backward(inp, row_id_map, probs)
ctx.fp8 = fp8
return unpermuted_output return unpermuted_output
@staticmethod @staticmethod
...@@ -219,17 +166,7 @@ class _moe_unpermute_index_map(torch.autograd.Function): ...@@ -219,17 +166,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
if not unpermuted_act_grad.is_contiguous(): if not unpermuted_act_grad.is_contiguous():
unpermuted_act_grad = unpermuted_act_grad.contiguous() unpermuted_act_grad = unpermuted_act_grad.contiguous()
if ctx.fp8: dtype = TE_DType[unpermuted_act_grad.dtype]
assert isinstance(
unpermuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
dtype = unpermuted_act_grad._fp8_dtype
fp8_scale_inv = unpermuted_act_grad._scale_inv
fake_dtype = unpermuted_act_grad.dtype
unpermuted_act_grad = unpermuted_act_grad._data
else:
dtype = TE_DType[unpermuted_act_grad.dtype]
inp, row_id_map, probs = ctx.saved_tensors inp, row_id_map, probs = ctx.saved_tensors
act_grad = None act_grad = None
...@@ -238,14 +175,6 @@ class _moe_unpermute_index_map(torch.autograd.Function): ...@@ -238,14 +175,6 @@ class _moe_unpermute_index_map(torch.autograd.Function):
act_grad, prob_grad = tex.moe_unpermute_bwd( act_grad, prob_grad = tex.moe_unpermute_bwd(
unpermuted_act_grad, inp, dtype, row_id_map, probs unpermuted_act_grad, inp, dtype, row_id_map, probs
) )
if ctx.fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=act_grad.shape,
dtype=fake_dtype,
)
if not ctx.needs_input_grad[2]: if not ctx.needs_input_grad[2]:
prob_grad = None prob_grad = None
...@@ -282,29 +211,86 @@ class _moe_permute_mask_map(torch.autograd.Function): ...@@ -282,29 +211,86 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts) row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)
fp8 = isinstance(inp, Float8Tensor) fp8 = isinstance(inp, QuantizedTensor)
per_tensor_recipe = isinstance(inp, Float8Tensor)
blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor)
mxfp8_recipe = isinstance(inp, MXFP8Tensor)
if fp8: if fp8:
fp8_dtype = inp._fp8_dtype fp8_dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype fake_dtype = inp.dtype
inp = inp._data # blockwise scaling
output, permuted_probs = triton_permutation.permute_with_mask_map( if blockwise_recipe:
fp8_scale = inp._rowwise_scale_inv.T.contiguous()
scale_hidden_dim = fp8_scale.shape[1]
assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
inp = inp._rowwise_data
# mxfp8 scaling
elif mxfp8_recipe:
fp8_scale = inp._rowwise_scale_inv.contiguous()
scale_hidden_dim = fp8_scale.shape[1]
assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
inp = inp._rowwise_data
# per-tensor scaling
elif per_tensor_recipe:
# Kernel does not need scale in per-tensor scaling
fp8_scale = None
scale_hidden_dim = None
fp8_scale_inv = inp._scale_inv
inp = inp._data
else:
raise ValueError("Unsupported FP8 recipe")
else:
fp8_scale = None
fp8_dtype = None
scale_hidden_dim = None
output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map(
inp, inp,
row_id_map, row_id_map,
probs, probs,
fp8_scale,
num_tokens, num_tokens,
num_experts, num_experts,
num_out_tokens, num_out_tokens,
hidden_size, hidden_size,
scale_hidden_dim,
) )
if fp8: if fp8:
output = Float8Tensor( if per_tensor_recipe:
data=output, output = Float8Tensor(
fp8_dtype=fp8_dtype, data=output,
fp8_scale_inv=fp8_scale_inv, fp8_dtype=fp8_dtype,
shape=output.shape, fp8_scale_inv=fp8_scale_inv,
dtype=fake_dtype, shape=output.shape,
) dtype=fake_dtype,
)
elif blockwise_recipe:
output = Float8BlockwiseQTensor(
shape=output.shape,
dtype=fake_dtype,
rowwise_data=output,
rowwise_scale_inv=permuted_scale.T.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
quantizer=None,
is_2D_scaled=False,
requires_grad=output.requires_grad,
)
elif mxfp8_recipe:
output = MXFP8Tensor(
shape=output.shape,
dtype=fake_dtype,
fp8_dtype=fp8_dtype,
rowwise_data=output,
rowwise_scale_inv=permuted_scale.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
quantizer=None,
requires_grad=output.requires_grad,
)
ctx.save_for_backward(row_id_map) ctx.save_for_backward(row_id_map)
ctx.num_experts = num_experts ctx.num_experts = num_experts
...@@ -327,14 +313,9 @@ class _moe_permute_mask_map(torch.autograd.Function): ...@@ -327,14 +313,9 @@ class _moe_permute_mask_map(torch.autograd.Function):
probs_grad = None probs_grad = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
(row_id_map,) = ctx.saved_tensors (row_id_map,) = ctx.saved_tensors
fp8 = isinstance(permuted_act_grad, Float8Tensor) assert not isinstance(
if fp8: permuted_act_grad, QuantizedTensor
fp8_dtype = permuted_act_grad._fp8_dtype ), "The backward of moe_permute does not support FP8."
fp8_scale_inv = permuted_act_grad._scale_inv
fake_dtype = permuted_act_grad.dtype
permuted_act_grad = permuted_act_grad._data
else:
fp8_dtype = None
act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
permuted_act_grad, permuted_act_grad,
row_id_map, row_id_map,
...@@ -343,16 +324,7 @@ class _moe_permute_mask_map(torch.autograd.Function): ...@@ -343,16 +324,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
ctx.num_tokens, ctx.num_tokens,
ctx.num_experts, ctx.num_experts,
ctx.hidden_size, ctx.hidden_size,
fp8_dtype,
) )
if fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv * ctx.num_experts,
shape=act_grad.shape,
dtype=fake_dtype,
)
if not ctx.needs_input_grad[3]: if not ctx.needs_input_grad[3]:
probs_grad = None probs_grad = None
return act_grad, None, None, probs_grad return act_grad, None, None, probs_grad
...@@ -366,8 +338,8 @@ class _moe_unpermute_mask_map(torch.autograd.Function): ...@@ -366,8 +338,8 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
ctx, ctx,
inp: torch.Tensor, inp: torch.Tensor,
row_id_map: torch.Tensor, row_id_map: torch.Tensor,
merging_probs: torch.Tensor, merging_probs: Optional[torch.Tensor],
restore_shape: torch.Size, restore_shape: Optional[torch.Size],
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if not inp.numel(): if not inp.numel():
...@@ -387,17 +359,9 @@ class _moe_unpermute_mask_map(torch.autograd.Function): ...@@ -387,17 +359,9 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
assert inp.is_cuda, "TransformerEngine needs CUDA." assert inp.is_cuda, "TransformerEngine needs CUDA."
assert row_id_map.is_cuda, "TransformerEngine needs CUDA." assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
fp8 = isinstance(inp, Float8Tensor) assert not isinstance(
if fp8: inp, QuantizedTensor
fp8_dtype = inp._fp8_dtype ), "The forward of moe_unpermute does not support FP8."
if not with_probs:
fp8_scale_inv = inp._scale_inv * num_experts
else:
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
fp8_dtype = None
unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
inp, inp,
row_id_map, row_id_map,
...@@ -406,16 +370,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function): ...@@ -406,16 +370,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
num_tokens, num_tokens,
num_experts, num_experts,
hidden_size, hidden_size,
fp8_dtype=fp8_dtype,
) )
if fp8:
unpermuted_output = Float8Tensor(
data=unpermuted_output,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=unpermuted_output.shape,
dtype=fake_dtype,
)
if with_probs: if with_probs:
ctx.save_for_backward(inp, row_id_map, merging_probs) ctx.save_for_backward(inp, row_id_map, merging_probs)
...@@ -442,16 +397,44 @@ class _moe_unpermute_mask_map(torch.autograd.Function): ...@@ -442,16 +397,44 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
else: else:
(row_id_map,) = ctx.saved_tensors (row_id_map,) = ctx.saved_tensors
fp8 = isinstance(unpermuted_act_grad, Float8Tensor) fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor)
mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor)
if fp8: if fp8:
fp8_dtype = unpermuted_act_grad._fp8_dtype fp8_dtype = unpermuted_act_grad._fp8_dtype
fp8_scale_inv = unpermuted_act_grad._scale_inv
fake_dtype = unpermuted_act_grad.dtype fake_dtype = unpermuted_act_grad.dtype
unpermuted_act_grad = unpermuted_act_grad._data # per-tensor scaling
if per_tensor_recipe:
# Kernel does not need scale in per-tensor scaling
fp8_scale = None
scale_hidden_dim = None
fp8_scale_inv = unpermuted_act_grad._scale_inv
unpermuted_act_grad = unpermuted_act_grad._data
# blockwise scaling
elif blockwise_recipe:
fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
unpermuted_act_grad = unpermuted_act_grad._rowwise_data
scale_hidden_dim = fp8_scale.shape[1]
assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
# mxfp8 scaling
elif mxfp8_recipe:
fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous()
unpermuted_act_grad = unpermuted_act_grad._rowwise_data
scale_hidden_dim = fp8_scale.shape[1]
assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
else:
raise ValueError("Unsupported FP8 recipe")
else: else:
scale_hidden_dim = None
fp8_dtype = None fp8_dtype = None
fp8_scale = None
if ctx.with_probs: if ctx.with_probs:
assert (
not fp8
), "The backward of moe_unpermute with merging probs does not support FP8."
act_grad, probs_grad = ( act_grad, probs_grad = (
triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
unpermuted_act_grad, unpermuted_act_grad,
...@@ -462,28 +445,55 @@ class _moe_unpermute_mask_map(torch.autograd.Function): ...@@ -462,28 +445,55 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
ctx.num_experts, ctx.num_experts,
ctx.num_permuted_tokens, ctx.num_permuted_tokens,
ctx.hidden_size, ctx.hidden_size,
fp8_dtype,
) )
) )
else: else:
act_grad, _ = triton_permutation.permute_with_mask_map( act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map(
unpermuted_act_grad, unpermuted_act_grad,
row_id_map, row_id_map,
None, None,
fp8_scale,
ctx.num_tokens, ctx.num_tokens,
ctx.num_experts, ctx.num_experts,
ctx.num_permuted_tokens, ctx.num_permuted_tokens,
ctx.hidden_size, ctx.hidden_size,
scale_hidden_dim,
) )
if fp8: if fp8:
act_grad = Float8Tensor( if per_tensor_recipe:
data=act_grad, act_grad = Float8Tensor(
fp8_dtype=fp8_dtype, data=act_grad,
fp8_scale_inv=fp8_scale_inv, fp8_dtype=fp8_dtype,
shape=act_grad.shape, fp8_scale_inv=fp8_scale_inv,
dtype=fake_dtype, shape=act_grad.shape,
) dtype=fake_dtype,
)
elif blockwise_recipe:
act_grad = Float8BlockwiseQTensor(
shape=act_grad.shape,
dtype=fake_dtype,
rowwise_data=act_grad,
rowwise_scale_inv=permuted_scale.T.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
quantizer=None,
is_2D_scaled=False,
requires_grad=act_grad.requires_grad,
)
elif mxfp8_recipe:
act_grad = MXFP8Tensor(
shape=act_grad.shape,
dtype=fake_dtype,
fp8_dtype=fp8_dtype,
rowwise_data=act_grad,
rowwise_scale_inv=permuted_scale.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
quantizer=None,
requires_grad=act_grad.requires_grad,
)
if not ctx.needs_input_grad[2]: if not ctx.needs_input_grad[2]:
probs_grad = None probs_grad = None
...@@ -568,10 +578,10 @@ def moe_permute_with_probs( ...@@ -568,10 +578,10 @@ def moe_permute_with_probs(
def moe_unpermute( def moe_unpermute(
inp: torch.Tensor, inp: torch.Tensor,
row_id_map: torch.Tensor, row_id_map: torch.Tensor,
merging_probs: torch.Tensor = None, merging_probs: Optional[torch.Tensor] = None,
restore_shape: torch.Tensor = None, restore_shape: Optional[torch.Size] = None,
map_type: str = "mask", map_type: str = "mask",
probs: torch.Tensor = None, probs: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
...@@ -588,7 +598,7 @@ def moe_unpermute( ...@@ -588,7 +598,7 @@ def moe_unpermute(
The tensor of probabilities corresponding to the permuted tokens. If provided, The tensor of probabilities corresponding to the permuted tokens. If provided,
the unpermuted tokens will be merged with their respective probabilities. the unpermuted tokens will be merged with their respective probabilities.
By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
restore_shape: torch.Tensor restore_shape: torch.Size, default = None
The output shape after the unpermute operation. The output shape after the unpermute operation.
map_type: str, default = 'mask' map_type: str, default = 'mask'
Type of the routing map tensor. Should be the same as the value passed to moe_permute. Type of the routing map tensor. Should be the same as the value passed to moe_permute.
......
...@@ -10,8 +10,6 @@ import torch ...@@ -10,8 +10,6 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from transformer_engine_torch import DType as TE_DType
@triton.jit @triton.jit
def _row_id_map_pass_1_kernel( def _row_id_map_pass_1_kernel(
...@@ -116,11 +114,14 @@ def _permute_kernel( ...@@ -116,11 +114,14 @@ def _permute_kernel(
output_ptr, output_ptr,
row_id_map_ptr, row_id_map_ptr,
probs_ptr, probs_ptr,
scale_ptr,
permuted_probs_ptr, permuted_probs_ptr,
permuted_scale_ptr,
# sizes # sizes
num_tokens, num_tokens,
num_experts, num_experts,
hidden_size, hidden_size,
scale_hidden_dim,
# strides # strides
stride_input_token, stride_input_token,
stride_input_hidden, stride_input_hidden,
...@@ -128,9 +129,14 @@ def _permute_kernel( ...@@ -128,9 +129,14 @@ def _permute_kernel(
stride_output_hidden, stride_output_hidden,
stride_probs_token, stride_probs_token,
stride_probs_expert, stride_probs_expert,
stride_scale_token,
stride_scale_hidden,
stride_permuted_probs_token, stride_permuted_probs_token,
stride_permuted_scale_token,
stride_permuted_scale_hidden,
# metas # metas
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
pid = tl.program_id(0) pid = tl.program_id(0)
...@@ -140,11 +146,21 @@ def _permute_kernel( ...@@ -140,11 +146,21 @@ def _permute_kernel(
mask = cur_off < hidden_size mask = cur_off < hidden_size
input_off = pid * stride_input_token + cur_off * stride_input_hidden input_off = pid * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask) inp = tl.load(input_ptr + input_off, mask=mask)
if PERMUTE_SCALE:
mask_scale = cur_off < scale_hidden_dim
scale_off = pid * stride_scale_token + cur_off * stride_scale_hidden
scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
for expert_idx in range(num_experts): for expert_idx in range(num_experts):
dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
if dst_row != -1: if dst_row != -1:
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
tl.store(output_ptr + output_off, inp, mask=mask) tl.store(output_ptr + output_off, inp, mask=mask)
if PERMUTE_SCALE:
permuted_scale_off = (
dst_row * stride_permuted_scale_token
+ cur_off * stride_permuted_scale_hidden
)
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS: if PERMUTE_PROBS:
if cur_pos == 0: if cur_pos == 0:
prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
...@@ -173,10 +189,12 @@ def permute_with_mask_map( ...@@ -173,10 +189,12 @@ def permute_with_mask_map(
inp: torch.Tensor, inp: torch.Tensor,
row_id_map: torch.Tensor, row_id_map: torch.Tensor,
probs: torch.Tensor, probs: torch.Tensor,
scale: torch.Tensor,
num_tokens: int, num_tokens: int,
num_experts: int, num_experts: int,
num_out_tokens: int, num_out_tokens: int,
hidden_size: int, hidden_size: int,
scale_hidden_dim: int,
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
...@@ -184,26 +202,42 @@ def permute_with_mask_map( ...@@ -184,26 +202,42 @@ def permute_with_mask_map(
permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
else: else:
permuted_probs = None permuted_probs = None
if scale is not None:
permuted_scale = torch.empty(
(num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda"
)
else:
permuted_scale = None
grid = (num_tokens,) grid = (num_tokens,)
_permute_kernel[grid]( _permute_kernel[grid](
inp, inp,
output, output,
row_id_map, row_id_map,
probs, probs,
scale,
permuted_probs, permuted_probs,
permuted_scale,
num_tokens, num_tokens,
num_experts, num_experts,
hidden_size, hidden_size,
scale_hidden_dim,
inp.stride(0), inp.stride(0),
inp.stride(1), inp.stride(1),
output.stride(0), output.stride(0),
output.stride(1), output.stride(1),
probs.stride(0) if probs is not None else None, probs.stride(0) if probs is not None else None,
probs.stride(1) if probs is not None else None, probs.stride(1) if probs is not None else None,
scale.stride(0) if scale is not None else None,
scale.stride(1) if scale is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None, permuted_probs.stride(0) if permuted_probs is not None else None,
permuted_scale.stride(0) if permuted_scale is not None else None,
permuted_scale.stride(1) if permuted_scale is not None else None,
PERMUTE_PROBS=probs is not None, PERMUTE_PROBS=probs is not None,
PERMUTE_SCALE=scale is not None,
) )
return output, permuted_probs return output, permuted_scale, permuted_probs
@triton.jit @triton.jit
...@@ -232,18 +266,9 @@ def _unpermute_kernel( ...@@ -232,18 +266,9 @@ def _unpermute_kernel(
# metas # metas
WITH_MERGING_PROBS: tl.constexpr, WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
FP8_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
if FP8_DTYPE == "e5m2": data_type = input_ptr.dtype.element_ty
data_type = tl.float8e5
pytorch_tensor_dtype = tl.uint8
elif FP8_DTYPE == "e4m3":
data_type = tl.float8e4nv
pytorch_tensor_dtype = tl.uint8
else:
data_type = input_ptr.dtype.element_ty
assert FP8_DTYPE is None
compute_type = tl.float32 compute_type = tl.float32
pid = tl.program_id(0) pid = tl.program_id(0)
...@@ -257,8 +282,6 @@ def _unpermute_kernel( ...@@ -257,8 +282,6 @@ def _unpermute_kernel(
if src_row != -1: if src_row != -1:
input_off = src_row * stride_input_token + current_offset * stride_input_hidden input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask) inp = tl.load(input_ptr + input_off, mask=mask)
if FP8_DTYPE is not None:
inp = inp.to(data_type, bitcast=True)
inp = inp.to(compute_type) inp = inp.to(compute_type)
if WITH_MERGING_PROBS: if WITH_MERGING_PROBS:
merging_prob_off = ( merging_prob_off = (
...@@ -279,14 +302,7 @@ def _unpermute_kernel( ...@@ -279,14 +302,7 @@ def _unpermute_kernel(
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
else: else:
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0) tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0)
if FP8_DTYPE is not None: accumulator = accumulator.to(data_type)
if not WITH_MERGING_PROBS:
# Directly adding these value may cause overflow for fp8, we scale it here.
# The outside fp8_scale_inv is also scaled in the meantime.
accumulator /= num_experts
accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
else:
accumulator = accumulator.to(data_type)
output_off = pid * stride_output_token + current_offset * stride_output_hidden output_off = pid * stride_output_token + current_offset * stride_output_hidden
tl.store(output_ptr + output_off, accumulator, mask=mask) tl.store(output_ptr + output_off, accumulator, mask=mask)
current_start += BLOCK_SIZE current_start += BLOCK_SIZE
...@@ -315,15 +331,8 @@ def unpermute_with_mask_map( ...@@ -315,15 +331,8 @@ def unpermute_with_mask_map(
num_tokens: int, num_tokens: int,
num_experts: int, num_experts: int,
hidden_size: int, hidden_size: int,
fp8_dtype: TE_DType,
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if fp8_dtype == TE_DType.kFloat8E5M2:
fp8_dtype = "e5m2"
elif fp8_dtype == TE_DType.kFloat8E4M3:
fp8_dtype = "e4m3"
else:
fp8_dtype = None
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if permuted_probs is not None: if permuted_probs is not None:
unpermuted_probs = torch.empty( unpermuted_probs = torch.empty(
...@@ -353,7 +362,6 @@ def unpermute_with_mask_map( ...@@ -353,7 +362,6 @@ def unpermute_with_mask_map(
unpermuted_probs.stride(1) if unpermuted_probs is not None else None, unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
WITH_MERGING_PROBS=merging_probs is not None, WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None, PERMUTE_PROBS=permuted_probs is not None,
FP8_DTYPE=fp8_dtype,
) )
return output, unpermuted_probs return output, unpermuted_probs
...@@ -383,18 +391,9 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -383,18 +391,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_grad_token, stride_merging_probs_grad_token,
stride_merging_probs_grad_expert, stride_merging_probs_grad_expert,
# metas # metas
FP8_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
if FP8_DTYPE == "e5m2": data_type = fwd_output_grad_ptr.dtype.element_ty
data_type = tl.float8e5
pytorch_tensor_dtype = tl.uint8
elif FP8_DTYPE == "e4m3":
data_type = tl.float8e4nv
pytorch_tensor_dtype = tl.uint8
else:
data_type = fwd_output_grad_ptr.dtype.element_ty
assert FP8_DTYPE is None
compute_type = tl.float32 compute_type = tl.float32
pid = tl.program_id(0) pid = tl.program_id(0)
...@@ -411,8 +410,6 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -411,8 +410,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
+ current_offset * stride_fwd_output_grad_hidden + current_offset * stride_fwd_output_grad_hidden
) )
inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
if FP8_DTYPE is not None:
inp = inp.to(data_type, bitcast=True)
inp = inp.to(compute_type) inp = inp.to(compute_type)
merging_prob_off = ( merging_prob_off = (
pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
...@@ -420,8 +417,6 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -420,8 +417,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
output = inp * merging_prob output = inp * merging_prob
output = output.to(data_type) output = output.to(data_type)
if FP8_DTYPE is not None:
output = output.to(pytorch_tensor_dtype, bitcast=True)
output_off = ( output_off = (
dst_row * stride_fwd_input_grad_token dst_row * stride_fwd_input_grad_token
+ current_offset * stride_fwd_input_grad_hidden + current_offset * stride_fwd_input_grad_hidden
...@@ -432,8 +427,6 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -432,8 +427,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
) )
fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
if FP8_DTYPE is not None:
fwd_input = fwd_input.to(data_type, bitcast=True)
prob_grad_accum += fwd_input.to(compute_type) * inp prob_grad_accum += fwd_input.to(compute_type) * inp
current_start += BLOCK_SIZE current_start += BLOCK_SIZE
probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
...@@ -474,15 +467,8 @@ def unpermute_with_mask_map_bwd_with_merging_probs( ...@@ -474,15 +467,8 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
num_experts: int, num_experts: int,
num_out_tokens: int, num_out_tokens: int,
hidden_size: int, hidden_size: int,
fp8_dtype: TE_DType,
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
if fp8_dtype == TE_DType.kFloat8E5M2:
fp8_dtype = "e5m2"
elif fp8_dtype == TE_DType.kFloat8E4M3:
fp8_dtype = "e4m3"
else:
fp8_dtype = None
act_grad = torch.empty( act_grad = torch.empty(
(num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
) )
...@@ -510,7 +496,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs( ...@@ -510,7 +496,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs.stride(1), merging_probs.stride(1),
merging_probs_grad.stride(0), merging_probs_grad.stride(0),
merging_probs_grad.stride(1), merging_probs_grad.stride(1),
fp8_dtype,
) )
return act_grad, merging_probs_grad return act_grad, merging_probs_grad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment