Unverified Commit c8e7cc02 authored by Autumn1998's avatar Autumn1998 Committed by GitHub
Browse files

[MoE] Support new fp8 recipes for permute_fusion (#1649)



* add support for new recipe on permute_fusion, rm fp unpermute
Signed-off-by: default avatartongliu <tongliu@nvidia.com>

* fix lint
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* remove fp8 from index map
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* skip unsupported tests
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatartongliu <tongliu@nvidia.com>
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatartongliu <tongliu@nvidia.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d9eb0582
......@@ -8,6 +8,7 @@ import torch
import pytest
from typing import Dict, List
from transformer_engine.common import recipe
from transformer_engine.pytorch import (
moe_permute as te_permute,
moe_permute_with_probs as te_permute_with_probs,
......@@ -17,9 +18,14 @@ from transformer_engine.pytorch import (
)
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex
import copy
seed = 1234
torch.manual_seed(seed)
......@@ -234,7 +240,6 @@ def _test_permutation_index_map(
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
)
fp8 = False
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
......@@ -242,48 +247,12 @@ def _test_permutation_index_map(
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else:
pytest.skip("Invalid dtype.")
if fp8:
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
permute_bwd_input = torch.rand(
size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
unpermute_bwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
_permute_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_permute_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_unpermute_bwd_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input)
permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input)
unpermute_bwd_input = _unpermute_bwd_quantizer(unpermute_bwd_input)
pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16)
pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_fwd_input.requires_grad_(True)
......@@ -323,9 +292,9 @@ def _test_permutation_index_map(
# TE Permutation
#
###################################################################################################################################
te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
te_permute_fwd_input = pytorch_permute_fwd_input.detach()
te_permute_fwd_input.requires_grad_(True)
te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()
te_permute_bwd_input = pytorch_permute_bwd_input.detach()
te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, indices, num_out_tokens, map_type="index"
......@@ -338,7 +307,7 @@ def _test_permutation_index_map(
te_probs.requires_grad_(True)
te_unpermute_fwd_input = te_permute_output.detach()
te_unpermute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()
te_unpermute_output = te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"
......@@ -352,16 +321,10 @@ def _test_permutation_index_map(
###################################################################################################################################
tols = dtype_tols(te_dtype)
if fp8:
te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32)
te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32)
te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32)
else:
te_permute_output_ = te_permute_output.float()
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
te_permute_output_ = te_permute_output.float()
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
torch.testing.assert_close(
pytorch_permute_output.float(),
......@@ -487,7 +450,6 @@ def _test_permutation_mask_map(
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
)
fp8 = False
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
......@@ -495,49 +457,12 @@ def _test_permutation_mask_map(
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else:
pytest.skip("Invalid dtype.")
if fp8:
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
permute_bwd_input = torch.rand(
size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
unpermute_bwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
_permute_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_permute_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_unpermute_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input)
permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input)
unpermute_bwd_input = _unpermute_bwd_input_quantizer(unpermute_bwd_input)
pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16)
pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_fwd_input.requires_grad_(True)
......@@ -553,10 +478,7 @@ def _test_permutation_mask_map(
probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums
if fp8:
probs = probs.to(torch.float16)
else:
probs = probs.to(dtype)
probs = probs.to(dtype)
probs.requires_grad_(True)
###################################################################################################################################
......@@ -582,9 +504,9 @@ def _test_permutation_mask_map(
# TE Permutation
#
###################################################################################################################################
te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
te_permute_fwd_input = pytorch_permute_fwd_input.detach()
te_permute_fwd_input.requires_grad_(True)
te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()
te_permute_bwd_input = pytorch_permute_bwd_input.detach()
te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
......@@ -597,7 +519,7 @@ def _test_permutation_mask_map(
te_probs.requires_grad_(True)
te_unpermute_fwd_input = te_permute_output.detach()
te_unpermute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()
te_unpermute_output = te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask"
......@@ -611,16 +533,10 @@ def _test_permutation_mask_map(
###################################################################################################################################
tols = dtype_tols(te_dtype)
if fp8:
te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32)
te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32)
te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32)
else:
te_permute_output_ = te_permute_output.float()
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
te_permute_output_ = te_permute_output.float()
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
torch.testing.assert_close(
pytorch_permute_output.float(),
......@@ -730,6 +646,118 @@ def _test_permutation_mask_map(
print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
def _test_permutation_mask_map_fp8(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
recipe,
):
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
if num_out_tokens == None:
num_out_tokens = num_tokens * topK
if recipe.delayed():
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
elif recipe.float8_current_scaling():
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=te_dtype,
device=torch.device("cuda"),
columnwise=False,
)
elif recipe.float8_block_scaling():
quantizer = Float8BlockQuantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=False,
amax_epsilon=0.0,
force_pow_2_scales=True, # Fp8 sub-channel a2a requires e8 scales
block_scaling_dim=1, # 1x128 scaling
)
elif recipe.mxfp8():
quantizer = MXFP8Quantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=False,
)
else:
raise ValueError("Unsupported FP8 recipe")
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
# Make an empty fp8 tensor
permute_fwd_input_fp8 = quantizer.make_empty(
permute_fwd_input.shape,
dtype=permute_fwd_input.dtype,
device=permute_fwd_input.device,
)
# quantize the tensor
quantizer.update_quantized(permute_fwd_input, permute_fwd_input_fp8)
if recipe.float8_block_scaling():
pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data)
pytorch_permute_fwd_scale_input = copy.deepcopy(
permute_fwd_input_fp8._rowwise_scale_inv.T.contiguous()
)
elif recipe.mxfp8():
pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data)
pytorch_permute_fwd_scale_input = copy.deepcopy(
permute_fwd_input_fp8._rowwise_scale_inv.contiguous()
)
else:
pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._data)
pytorch_permute_fwd_scale_input = None
_tmp_tensor = torch.zeros((num_tokens * num_expert,))
_tmp_tensor[: int(num_out_tokens)] = 1.0
_tmp_idx = torch.randperm(num_tokens * num_expert)
routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()
# PyTorch Permutaion
pytorch_permute_output, _ = pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map)
if pytorch_permute_fwd_scale_input is not None:
pytorch_permute_scale_output, _ = pytorch_permute_mask_map(
pytorch_permute_fwd_scale_input, routing_map
)
# TE Permutation
permute_output, _ = te_permute(
permute_fwd_input_fp8, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
)
if recipe.float8_block_scaling():
te_permute_output = permute_output._rowwise_data
te_permute_scale_output = permute_output._rowwise_scale_inv.T.contiguous()
elif recipe.mxfp8():
te_permute_output = permute_output._rowwise_data
te_permute_scale_output = permute_output._rowwise_scale_inv.contiguous()
else:
te_permute_output = permute_output._data
te_permute_scale_output = None
# check the permute output
torch.testing.assert_close(
pytorch_permute_output,
te_permute_output,
atol=0,
rtol=0,
)
if recipe.float8_block_scaling() or recipe.mxfp8():
torch.testing.assert_close(
pytorch_permute_scale_output,
te_permute_scale_output,
atol=0,
rtol=0,
)
def _test_moe_chunk_sort(
te_dtype,
num_tokens,
......@@ -743,7 +771,6 @@ def _test_moe_chunk_sort(
f" token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert} tp_size:{tp_size} {te_dtype}"
)
fp8 = False
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
......@@ -751,34 +778,11 @@ def _test_moe_chunk_sort(
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else:
pytest.skip("Invalid dtype.")
if fp8:
fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_bwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
fwd_input = _fwd_input_quantizer.quantize(fwd_input)
bwd_input = _bwd_input_quantizer.quantize(bwd_input)
pytorch_fwd_input = fwd_input.dequantize(dtype=torch.float16)
pytorch_bwd_input = bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_fwd_input.requires_grad_(True)
......@@ -806,9 +810,9 @@ def _test_moe_chunk_sort(
# TE Permutation
#
###################################################################################################################################
te_fwd_input = fwd_input if fp8 else pytorch_fwd_input.detach()
te_fwd_input = pytorch_fwd_input.detach()
te_fwd_input.requires_grad_(True)
te_bwd_input = bwd_input if fp8 else pytorch_bwd_input.detach()
te_bwd_input = pytorch_bwd_input.detach()
te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda)
te_output.backward(te_bwd_input, retain_graph=True)
......@@ -820,12 +824,8 @@ def _test_moe_chunk_sort(
###################################################################################################################################
tols = dtype_tols(te_dtype)
if fp8:
te_output_ = te_output.dequantize(dtype=torch.float32)
te_fwd_input_grad = te_fwd_input.grad.dequantize(dtype=torch.float32)
else:
te_output_ = te_output.float()
te_fwd_input_grad = te_fwd_input.grad.float()
te_output_ = te_output.float()
te_fwd_input_grad = te_fwd_input.grad.float()
torch.testing.assert_close(
pytorch_output.float(),
......@@ -899,7 +899,6 @@ def _test_permutation_mask_map_alongside_probs(
f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}"
)
fp8 = False
# Convert TE dtypes to PyTorch dtypes
if te_dtype == tex.DType.kFloat32:
dtype = torch.float32
......@@ -907,38 +906,11 @@ def _test_permutation_mask_map_alongside_probs(
dtype = torch.float16
elif te_dtype == tex.DType.kBFloat16:
dtype = torch.bfloat16
elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
dtype = torch.uint8
fp8 = True
else:
pytest.skip("Invalid dtype.")
if fp8:
permute_fwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
unpermute_bwd_input = torch.rand(
size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda"
)
_permute_fwd_input_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
_unpermute_bwd_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
permute_fwd_input = _permute_fwd_input_quantizer.quantize(permute_fwd_input)
unpermute_bwd_input = _unpermute_bwd_quantizer.quantize(unpermute_bwd_input)
pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
else:
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
pytorch_permute_fwd_input.requires_grad_(True)
......@@ -952,10 +924,7 @@ def _test_permutation_mask_map_alongside_probs(
probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
row_sums = probs.sum(dim=1, keepdim=True)
probs = probs / row_sums
if fp8:
probs = probs.to(torch.float16)
else:
probs = probs.to(dtype)
probs = probs.to(dtype)
probs.requires_grad_(True)
split_sizes = [0] * (num_expert * tp_size)
......@@ -1006,13 +975,12 @@ def _test_permutation_mask_map_alongside_probs(
# TE Permutation
#
###################################################################################################################################
te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
te_permute_fwd_input = pytorch_permute_fwd_input.detach()
te_permute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()
te_probs = probs.detach()
te_probs.requires_grad_(True)
print(te_probs.shape)
te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
te_permute_fwd_input,
......@@ -1020,27 +988,14 @@ def _test_permutation_mask_map_alongside_probs(
routing_map,
num_out_tokens=num_out_tokens,
)
print(te_permuted_probs.shape)
te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs(
te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda
)
if fp8:
_permute_output_quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda().squeeze(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=te_dtype,
)
te_permute_output = te_permute_output.dequantize(dtype=torch.float32)
te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1)
te_permute_output = _permute_output_quantizer.quantize(te_permute_output)
else:
te_permute_output_dtype = te_permute_output.dtype
print(te_permute_output.shape)
print(te_permuted_probs.shape)
te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1)
te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype)
te_permute_output_dtype = te_permute_output.dtype
te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1)
te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype)
te_permute_output = te_sort_chunks_by_index(
te_permute_output, split_sizes_2_cuda, sorted_idxs_2_cuda
......@@ -1058,13 +1013,8 @@ def _test_permutation_mask_map_alongside_probs(
tols = dtype_tols(te_dtype)
if fp8:
# backward of dequantize is in high precision
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
else:
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
torch.testing.assert_close(
pytorch_unpermute_output.float(),
......@@ -1228,6 +1178,16 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype):
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
......@@ -1237,36 +1197,7 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
def test_permutation_index_map_fp8(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
):
with_probs = True
BENCHMARK = False
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_permutation_mask_map_fp8(
te_dtype,
num_tokens,
......@@ -1274,47 +1205,21 @@ def test_permutation_mask_map_fp8(
hidden_size,
topK,
num_out_tokens,
recipe,
):
with_probs = True
BENCHMARK = False
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
def test_permutation_mask_map_alongside_probs_fp8(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
tp_size,
):
_test_permutation_mask_map_alongside_probs(
_test_permutation_mask_map_fp8(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
tp_size=tp_size,
recipe=recipe,
)
......@@ -1415,11 +1320,9 @@ def test_permutation_single_case():
# te_dtype = tex.DType.kFloat32
# te_dtype = tex.DType.kFloat16
# te_dtype = tex.DType.kBFloat16
te_dtype = tex.DType.kFloat8E5M2
# te_dtype = tex.DType.kFloat8E4M3
te_dtype = tex.DType.kBFloat16
num_tokens = 10
num_tokens = 12
num_expert = 4
hidden_size = 16
topK = 2
......
......@@ -333,7 +333,7 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
const transformer_engine::Tensor *input_fwd_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input_fwd);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T,
nvte_permute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
reinterpret_cast<T *>(output_cu->data.dptr),
......@@ -359,7 +359,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T,
nvte_unpermute_launcher(reinterpret_cast<const T *>(input_cu->data.dptr),
reinterpret_cast<T *>(output_cu->data.dptr),
......
......@@ -52,18 +52,11 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr,
num_tokens * topK);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc
num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
at::Tensor permuted_output = torch::empty(
{num_out_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
at::Tensor permuted_output =
torch::empty({num_out_tokens, num_cols},
torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
at::Tensor row_id_map = torch::empty(
{num_tokens * topK}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
......@@ -100,17 +93,10 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
using namespace transformer_engine::pytorch;
int num_cols = input.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input.scalar_type();
// Output buffer alloc
at::Tensor unpermuted_output = torch::empty(
{num_tokens, num_cols}, torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
at::Tensor unpermuted_output =
torch::empty({num_tokens, num_cols},
torch::dtype(input.scalar_type()).device(torch::kCUDA).requires_grad(false));
auto stream = at::cuda::getCurrentCUDAStream().stream();
......@@ -136,17 +122,10 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
int num_cols = input_bwd.size(1);
// Activations type
at::ScalarType _st;
if (dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2)
_st = at::ScalarType::Byte;
else
_st = input_bwd.scalar_type();
// Output buffer alloc
at::Tensor act_grad = torch::empty({input_fwd.size(0), num_cols},
torch::dtype(_st).device(torch::kCUDA).requires_grad(false));
at::Tensor act_grad =
torch::empty({input_fwd.size(0), num_cols},
torch::dtype(input_bwd.scalar_type()).device(torch::kCUDA).requires_grad(false));
at::Tensor prob_grad = torch::empty(
{num_tokens, topK}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
......
......@@ -4,14 +4,16 @@
"""MoE Permutaion API"""
import warnings
from typing import Tuple
from typing import Optional, Tuple
import torch
import transformer_engine_torch as tex
import transformer_engine.pytorch.triton.permutation as triton_permutation
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
__all__ = [
"moe_permute",
......@@ -46,17 +48,7 @@ class _moe_permute_index_map(torch.autograd.Function):
assert inp.size(0) == index.size(0), "Permute not possible"
# Data type check
fp8 = isinstance(inp, Float8Tensor)
if fp8:
assert (
inp._quantizer.scale.ndim == 0
), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute."
dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
dtype = TE_DType[inp.dtype]
dtype = TE_DType[inp.dtype]
if index.dtype != torch.int32:
warnings.warn(
f"The data type of the input `index` of Permute is {index.dtype}! "
......@@ -80,19 +72,9 @@ class _moe_permute_index_map(torch.autograd.Function):
_moe_permute_index_map.max_expanded_token_num,
)
if fp8:
permuted_act = Float8Tensor(
data=permuted_act,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=permuted_act.shape,
dtype=fake_dtype,
)
ctx.row_id_map = row_id_map
ctx.num_tokens = index.size(0)
ctx.topK = index.size(1)
ctx.fp8 = fp8
return permuted_act, row_id_map
@staticmethod
......@@ -109,30 +91,12 @@ class _moe_permute_index_map(torch.autograd.Function):
if not permuted_act_grad.is_contiguous():
permuted_act_grad = permuted_act_grad.contiguous()
if ctx.fp8:
assert isinstance(
permuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
dtype = permuted_act_grad._fp8_dtype
fp8_scale_inv = permuted_act_grad._scale_inv
fake_dtype = permuted_act_grad.dtype
permuted_act_grad = permuted_act_grad._data
else:
dtype = TE_DType[permuted_act_grad.dtype]
dtype = TE_DType[permuted_act_grad.dtype]
act_grad = None
if ctx.needs_input_grad[0]:
act_grad = tex.moe_permute_bwd(
permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
)
if ctx.fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv * ctx.topK,
shape=act_grad.shape,
dtype=fake_dtype,
)
return act_grad, None, None, None
......@@ -176,14 +140,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
# Data type check
fp8 = isinstance(inp, Float8Tensor)
if fp8:
dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
dtype = TE_DType[inp.dtype]
dtype = TE_DType[inp.dtype]
if row_id_map.dtype != torch.int32:
warnings.warn(
f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
......@@ -193,17 +150,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)
if fp8:
unpermuted_output = Float8Tensor(
data=unpermuted_output,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=unpermuted_output.shape,
dtype=fake_dtype,
)
ctx.save_for_backward(inp, row_id_map, probs)
ctx.fp8 = fp8
return unpermuted_output
@staticmethod
......@@ -219,17 +166,7 @@ class _moe_unpermute_index_map(torch.autograd.Function):
if not unpermuted_act_grad.is_contiguous():
unpermuted_act_grad = unpermuted_act_grad.contiguous()
if ctx.fp8:
assert isinstance(
unpermuted_act_grad, Float8Tensor
), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
dtype = unpermuted_act_grad._fp8_dtype
fp8_scale_inv = unpermuted_act_grad._scale_inv
fake_dtype = unpermuted_act_grad.dtype
unpermuted_act_grad = unpermuted_act_grad._data
else:
dtype = TE_DType[unpermuted_act_grad.dtype]
dtype = TE_DType[unpermuted_act_grad.dtype]
inp, row_id_map, probs = ctx.saved_tensors
act_grad = None
......@@ -238,14 +175,6 @@ class _moe_unpermute_index_map(torch.autograd.Function):
act_grad, prob_grad = tex.moe_unpermute_bwd(
unpermuted_act_grad, inp, dtype, row_id_map, probs
)
if ctx.fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=dtype,
fp8_scale_inv=fp8_scale_inv,
shape=act_grad.shape,
dtype=fake_dtype,
)
if not ctx.needs_input_grad[2]:
prob_grad = None
......@@ -282,29 +211,86 @@ class _moe_permute_mask_map(torch.autograd.Function):
row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)
fp8 = isinstance(inp, Float8Tensor)
fp8 = isinstance(inp, QuantizedTensor)
per_tensor_recipe = isinstance(inp, Float8Tensor)
blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor)
mxfp8_recipe = isinstance(inp, MXFP8Tensor)
if fp8:
fp8_dtype = inp._fp8_dtype
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
output, permuted_probs = triton_permutation.permute_with_mask_map(
# blockwise scaling
if blockwise_recipe:
fp8_scale = inp._rowwise_scale_inv.T.contiguous()
scale_hidden_dim = fp8_scale.shape[1]
assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
inp = inp._rowwise_data
# mxfp8 scaling
elif mxfp8_recipe:
fp8_scale = inp._rowwise_scale_inv.contiguous()
scale_hidden_dim = fp8_scale.shape[1]
assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
inp = inp._rowwise_data
# per-tensor scaling
elif per_tensor_recipe:
# Kernel does not need scale in per-tensor scaling
fp8_scale = None
scale_hidden_dim = None
fp8_scale_inv = inp._scale_inv
inp = inp._data
else:
raise ValueError("Unsupported FP8 recipe")
else:
fp8_scale = None
fp8_dtype = None
scale_hidden_dim = None
output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map(
inp,
row_id_map,
probs,
fp8_scale,
num_tokens,
num_experts,
num_out_tokens,
hidden_size,
scale_hidden_dim,
)
if fp8:
output = Float8Tensor(
data=output,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=output.shape,
dtype=fake_dtype,
)
if per_tensor_recipe:
output = Float8Tensor(
data=output,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=output.shape,
dtype=fake_dtype,
)
elif blockwise_recipe:
output = Float8BlockwiseQTensor(
shape=output.shape,
dtype=fake_dtype,
rowwise_data=output,
rowwise_scale_inv=permuted_scale.T.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
quantizer=None,
is_2D_scaled=False,
requires_grad=output.requires_grad,
)
elif mxfp8_recipe:
output = MXFP8Tensor(
shape=output.shape,
dtype=fake_dtype,
fp8_dtype=fp8_dtype,
rowwise_data=output,
rowwise_scale_inv=permuted_scale.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
quantizer=None,
requires_grad=output.requires_grad,
)
ctx.save_for_backward(row_id_map)
ctx.num_experts = num_experts
......@@ -327,14 +313,9 @@ class _moe_permute_mask_map(torch.autograd.Function):
probs_grad = None
if ctx.needs_input_grad[0]:
(row_id_map,) = ctx.saved_tensors
fp8 = isinstance(permuted_act_grad, Float8Tensor)
if fp8:
fp8_dtype = permuted_act_grad._fp8_dtype
fp8_scale_inv = permuted_act_grad._scale_inv
fake_dtype = permuted_act_grad.dtype
permuted_act_grad = permuted_act_grad._data
else:
fp8_dtype = None
assert not isinstance(
permuted_act_grad, QuantizedTensor
), "The backward of moe_permute does not support FP8."
act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
permuted_act_grad,
row_id_map,
......@@ -343,16 +324,7 @@ class _moe_permute_mask_map(torch.autograd.Function):
ctx.num_tokens,
ctx.num_experts,
ctx.hidden_size,
fp8_dtype,
)
if fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv * ctx.num_experts,
shape=act_grad.shape,
dtype=fake_dtype,
)
if not ctx.needs_input_grad[3]:
probs_grad = None
return act_grad, None, None, probs_grad
......@@ -366,8 +338,8 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
ctx,
inp: torch.Tensor,
row_id_map: torch.Tensor,
merging_probs: torch.Tensor,
restore_shape: torch.Size,
merging_probs: Optional[torch.Tensor],
restore_shape: Optional[torch.Size],
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
if not inp.numel():
......@@ -387,17 +359,9 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
fp8 = isinstance(inp, Float8Tensor)
if fp8:
fp8_dtype = inp._fp8_dtype
if not with_probs:
fp8_scale_inv = inp._scale_inv * num_experts
else:
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
else:
fp8_dtype = None
assert not isinstance(
inp, QuantizedTensor
), "The forward of moe_unpermute does not support FP8."
unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
inp,
row_id_map,
......@@ -406,16 +370,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
num_tokens,
num_experts,
hidden_size,
fp8_dtype=fp8_dtype,
)
if fp8:
unpermuted_output = Float8Tensor(
data=unpermuted_output,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=unpermuted_output.shape,
dtype=fake_dtype,
)
if with_probs:
ctx.save_for_backward(inp, row_id_map, merging_probs)
......@@ -442,16 +397,44 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
else:
(row_id_map,) = ctx.saved_tensors
fp8 = isinstance(unpermuted_act_grad, Float8Tensor)
fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor)
mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor)
if fp8:
fp8_dtype = unpermuted_act_grad._fp8_dtype
fp8_scale_inv = unpermuted_act_grad._scale_inv
fake_dtype = unpermuted_act_grad.dtype
unpermuted_act_grad = unpermuted_act_grad._data
# per-tensor scaling
if per_tensor_recipe:
# Kernel does not need scale in per-tensor scaling
fp8_scale = None
scale_hidden_dim = None
fp8_scale_inv = unpermuted_act_grad._scale_inv
unpermuted_act_grad = unpermuted_act_grad._data
# blockwise scaling
elif blockwise_recipe:
fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
unpermuted_act_grad = unpermuted_act_grad._rowwise_data
scale_hidden_dim = fp8_scale.shape[1]
assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
# mxfp8 scaling
elif mxfp8_recipe:
fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous()
unpermuted_act_grad = unpermuted_act_grad._rowwise_data
scale_hidden_dim = fp8_scale.shape[1]
assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
else:
raise ValueError("Unsupported FP8 recipe")
else:
scale_hidden_dim = None
fp8_dtype = None
fp8_scale = None
if ctx.with_probs:
assert (
not fp8
), "The backward of moe_unpermute with merging probs does not support FP8."
act_grad, probs_grad = (
triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
unpermuted_act_grad,
......@@ -462,28 +445,55 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
ctx.num_experts,
ctx.num_permuted_tokens,
ctx.hidden_size,
fp8_dtype,
)
)
else:
act_grad, _ = triton_permutation.permute_with_mask_map(
act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map(
unpermuted_act_grad,
row_id_map,
None,
fp8_scale,
ctx.num_tokens,
ctx.num_experts,
ctx.num_permuted_tokens,
ctx.hidden_size,
scale_hidden_dim,
)
if fp8:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=act_grad.shape,
dtype=fake_dtype,
)
if per_tensor_recipe:
act_grad = Float8Tensor(
data=act_grad,
fp8_dtype=fp8_dtype,
fp8_scale_inv=fp8_scale_inv,
shape=act_grad.shape,
dtype=fake_dtype,
)
elif blockwise_recipe:
act_grad = Float8BlockwiseQTensor(
shape=act_grad.shape,
dtype=fake_dtype,
rowwise_data=act_grad,
rowwise_scale_inv=permuted_scale.T.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
fp8_dtype=fp8_dtype,
quantizer=None,
is_2D_scaled=False,
requires_grad=act_grad.requires_grad,
)
elif mxfp8_recipe:
act_grad = MXFP8Tensor(
shape=act_grad.shape,
dtype=fake_dtype,
fp8_dtype=fp8_dtype,
rowwise_data=act_grad,
rowwise_scale_inv=permuted_scale.contiguous(),
columnwise_data=None,
columnwise_scale_inv=None,
quantizer=None,
requires_grad=act_grad.requires_grad,
)
if not ctx.needs_input_grad[2]:
probs_grad = None
......@@ -568,10 +578,10 @@ def moe_permute_with_probs(
def moe_unpermute(
inp: torch.Tensor,
row_id_map: torch.Tensor,
merging_probs: torch.Tensor = None,
restore_shape: torch.Tensor = None,
merging_probs: Optional[torch.Tensor] = None,
restore_shape: Optional[torch.Size] = None,
map_type: str = "mask",
probs: torch.Tensor = None,
probs: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Unpermute a tensor with permuted tokens, and optionally merge the tokens with their
......@@ -588,7 +598,7 @@ def moe_unpermute(
The tensor of probabilities corresponding to the permuted tokens. If provided,
the unpermuted tokens will be merged with their respective probabilities.
By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
restore_shape: torch.Tensor
restore_shape: torch.Size, default = None
The output shape after the unpermute operation.
map_type: str, default = 'mask'
Type of the routing map tensor. Should be the same as the value passed to moe_permute.
......
......@@ -10,8 +10,6 @@ import torch
import triton
import triton.language as tl
from transformer_engine_torch import DType as TE_DType
@triton.jit
def _row_id_map_pass_1_kernel(
......@@ -116,11 +114,14 @@ def _permute_kernel(
output_ptr,
row_id_map_ptr,
probs_ptr,
scale_ptr,
permuted_probs_ptr,
permuted_scale_ptr,
# sizes
num_tokens,
num_experts,
hidden_size,
scale_hidden_dim,
# strides
stride_input_token,
stride_input_hidden,
......@@ -128,9 +129,14 @@ def _permute_kernel(
stride_output_hidden,
stride_probs_token,
stride_probs_expert,
stride_scale_token,
stride_scale_hidden,
stride_permuted_probs_token,
stride_permuted_scale_token,
stride_permuted_scale_hidden,
# metas
PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
......@@ -140,11 +146,21 @@ def _permute_kernel(
mask = cur_off < hidden_size
input_off = pid * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
if PERMUTE_SCALE:
mask_scale = cur_off < scale_hidden_dim
scale_off = pid * stride_scale_token + cur_off * stride_scale_hidden
scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
for expert_idx in range(num_experts):
dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid)
if dst_row != -1:
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
tl.store(output_ptr + output_off, inp, mask=mask)
if PERMUTE_SCALE:
permuted_scale_off = (
dst_row * stride_permuted_scale_token
+ cur_off * stride_permuted_scale_hidden
)
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS:
if cur_pos == 0:
prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
......@@ -173,10 +189,12 @@ def permute_with_mask_map(
inp: torch.Tensor,
row_id_map: torch.Tensor,
probs: torch.Tensor,
scale: torch.Tensor,
num_tokens: int,
num_experts: int,
num_out_tokens: int,
hidden_size: int,
scale_hidden_dim: int,
):
# pylint: disable=missing-function-docstring
output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
......@@ -184,26 +202,42 @@ def permute_with_mask_map(
permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
else:
permuted_probs = None
if scale is not None:
permuted_scale = torch.empty(
(num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda"
)
else:
permuted_scale = None
grid = (num_tokens,)
_permute_kernel[grid](
inp,
output,
row_id_map,
probs,
scale,
permuted_probs,
permuted_scale,
num_tokens,
num_experts,
hidden_size,
scale_hidden_dim,
inp.stride(0),
inp.stride(1),
output.stride(0),
output.stride(1),
probs.stride(0) if probs is not None else None,
probs.stride(1) if probs is not None else None,
scale.stride(0) if scale is not None else None,
scale.stride(1) if scale is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None,
permuted_scale.stride(0) if permuted_scale is not None else None,
permuted_scale.stride(1) if permuted_scale is not None else None,
PERMUTE_PROBS=probs is not None,
PERMUTE_SCALE=scale is not None,
)
return output, permuted_probs
return output, permuted_scale, permuted_probs
@triton.jit
......@@ -232,18 +266,9 @@ def _unpermute_kernel(
# metas
WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
FP8_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
if FP8_DTYPE == "e5m2":
data_type = tl.float8e5
pytorch_tensor_dtype = tl.uint8
elif FP8_DTYPE == "e4m3":
data_type = tl.float8e4nv
pytorch_tensor_dtype = tl.uint8
else:
data_type = input_ptr.dtype.element_ty
assert FP8_DTYPE is None
data_type = input_ptr.dtype.element_ty
compute_type = tl.float32
pid = tl.program_id(0)
......@@ -257,8 +282,6 @@ def _unpermute_kernel(
if src_row != -1:
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
if FP8_DTYPE is not None:
inp = inp.to(data_type, bitcast=True)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
merging_prob_off = (
......@@ -279,14 +302,7 @@ def _unpermute_kernel(
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
else:
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0)
if FP8_DTYPE is not None:
if not WITH_MERGING_PROBS:
# Directly adding these value may cause overflow for fp8, we scale it here.
# The outside fp8_scale_inv is also scaled in the meantime.
accumulator /= num_experts
accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True)
else:
accumulator = accumulator.to(data_type)
accumulator = accumulator.to(data_type)
output_off = pid * stride_output_token + current_offset * stride_output_hidden
tl.store(output_ptr + output_off, accumulator, mask=mask)
current_start += BLOCK_SIZE
......@@ -315,15 +331,8 @@ def unpermute_with_mask_map(
num_tokens: int,
num_experts: int,
hidden_size: int,
fp8_dtype: TE_DType,
):
# pylint: disable=missing-function-docstring
if fp8_dtype == TE_DType.kFloat8E5M2:
fp8_dtype = "e5m2"
elif fp8_dtype == TE_DType.kFloat8E4M3:
fp8_dtype = "e4m3"
else:
fp8_dtype = None
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if permuted_probs is not None:
unpermuted_probs = torch.empty(
......@@ -353,7 +362,6 @@ def unpermute_with_mask_map(
unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None,
FP8_DTYPE=fp8_dtype,
)
return output, unpermuted_probs
......@@ -383,18 +391,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_grad_token,
stride_merging_probs_grad_expert,
# metas
FP8_DTYPE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
if FP8_DTYPE == "e5m2":
data_type = tl.float8e5
pytorch_tensor_dtype = tl.uint8
elif FP8_DTYPE == "e4m3":
data_type = tl.float8e4nv
pytorch_tensor_dtype = tl.uint8
else:
data_type = fwd_output_grad_ptr.dtype.element_ty
assert FP8_DTYPE is None
data_type = fwd_output_grad_ptr.dtype.element_ty
compute_type = tl.float32
pid = tl.program_id(0)
......@@ -411,8 +410,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
+ current_offset * stride_fwd_output_grad_hidden
)
inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
if FP8_DTYPE is not None:
inp = inp.to(data_type, bitcast=True)
inp = inp.to(compute_type)
merging_prob_off = (
pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
......@@ -420,8 +417,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
output = inp * merging_prob
output = output.to(data_type)
if FP8_DTYPE is not None:
output = output.to(pytorch_tensor_dtype, bitcast=True)
output_off = (
dst_row * stride_fwd_input_grad_token
+ current_offset * stride_fwd_input_grad_hidden
......@@ -432,8 +427,6 @@ def _unpermute_bwd_with_merging_probs_kernel(
dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
)
fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
if FP8_DTYPE is not None:
fwd_input = fwd_input.to(data_type, bitcast=True)
prob_grad_accum += fwd_input.to(compute_type) * inp
current_start += BLOCK_SIZE
probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
......@@ -474,15 +467,8 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
num_experts: int,
num_out_tokens: int,
hidden_size: int,
fp8_dtype: TE_DType,
):
# pylint: disable=missing-function-docstring
if fp8_dtype == TE_DType.kFloat8E5M2:
fp8_dtype = "e5m2"
elif fp8_dtype == TE_DType.kFloat8E4M3:
fp8_dtype = "e4m3"
else:
fp8_dtype = None
act_grad = torch.empty(
(num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
)
......@@ -510,7 +496,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs.stride(1),
merging_probs_grad.stride(0),
merging_probs_grad.stride(1),
fp8_dtype,
)
return act_grad, merging_probs_grad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment