Unverified commit 8ddac3df authored by Xin Yao, committed by GitHub

[PyTorch] Remove `dtype` from args of permutation (#1145)



* remove dtype from args
* update docs with permutation ops

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
parent 4ddb0a7b
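
This change simplifies the public permutation API: `moe_permute` and `moe_unpermute` now infer the data type from the input tensor itself (from `Float8Tensor` metadata for FP8 inputs, from the torch dtype via `TE_DType` otherwise), so callers no longer pass a `tex.DType`. A minimal before/after sketch of a call site (tensor names are illustrative):

    # Before: the caller had to pass the TE dtype explicitly
    permuted, row_id_map = moe_permute(inp, te_dtype, indices, num_out_tokens)
    merged = moe_unpermute(permuted, te_dtype, row_id_map, probs)

    # After: the dtype is inferred from `inp`
    permuted, row_id_map = moe_permute(inp, indices, num_out_tokens)
    merged = moe_unpermute(permuted, row_id_map, probs)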
@@ -44,3 +44,7 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.make_graphed_callables

 .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
+
+.. autoapifunction:: transformer_engine.pytorch.moe_permute
+
+.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
@@ -23,3 +23,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
 pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
+pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
@@ -220,9 +220,7 @@ def _test_permutation(
     te_permute_fwd_input.requires_grad_(True)
     te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()

-    te_permute_output, row_id_map = te_permute(
-        te_permute_fwd_input, te_dtype, indices, num_out_tokens
-    )
+    te_permute_output, row_id_map = te_permute(te_permute_fwd_input, indices, num_out_tokens)
     te_permute_output.backward(te_permute_bwd_input, retain_graph=True)

     te_probs = None
@@ -233,7 +231,7 @@ def _test_permutation(
     te_unpermute_fwd_input.requires_grad_(True)
     te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
-    te_unpermute_output = te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs)
+    te_unpermute_output = te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
     te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

     ###################################################################################################################################
@@ -305,7 +303,7 @@ def _test_permutation(
         lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens)
     )
     t2 = perf_test_cuda_kernel(
-        lambda: te_permute(te_permute_fwd_input, te_dtype, indices, num_out_tokens)
+        lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens)
     )
     print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
@@ -333,7 +331,7 @@ def _test_permutation(
         lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs)
     )
     t2 = perf_test_cuda_kernel(
-        lambda: te_unpermute(te_unpermute_fwd_input, te_dtype, row_id_map, te_probs)
+        lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
    )
     print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")
...
@@ -8,7 +8,8 @@ from typing import Tuple

 import torch

 import transformer_engine_torch as tex
-from transformer_engine.pytorch.float8_tensor import Float8Tensor
+from .constants import TE_DType
+from .float8_tensor import Float8Tensor

 __all__ = [
@@ -27,14 +28,13 @@ class _moe_permute(torch.autograd.Function):
     def forward(
         ctx,
         inp: torch.Tensor,
-        dtype: tex.DType,
         indices: torch.Tensor,
         num_out_tokens: int,
         max_token_num: int,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Empty input check
         if not inp.numel():
-            return inp, None
+            return inp, torch.tensor([], device=inp.device)

         # Device check
         assert inp.is_cuda, "TransformerEngine needs CUDA."
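
With the empty-input path above, `forward` now returns an empty tensor rather than `None` for the row-id map, so the second output is a `torch.Tensor` in every case. A hedged sketch of what a caller can then rely on (the guard itself is illustrative, not part of this diff):

    permuted, row_id_map = moe_permute(inp, indices, num_out_tokens)
    # `row_id_map` is always a tensor; an empty input yields an empty map
    if row_id_map.numel() == 0:
        pass  # nothing was permuted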
@@ -43,16 +43,13 @@ class _moe_permute(torch.autograd.Function):
         assert inp.size(0) == indices.size(0), "Permute not possible"

         # Data type check
-        fp8 = False
-        if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
-            fp8 = True
+        fp8 = isinstance(inp, Float8Tensor)
         if fp8:
-            assert isinstance(
-                inp, Float8Tensor
-            ), "Input must be in Float8Tensor type for FP8 moe_permute."
-            fp8_dtype = inp._fp8_dtype
+            dtype = inp._fp8_dtype
             fp8_scale_inv = inp._scale_inv
             inp = inp._data
+        else:
+            dtype = TE_DType[inp.dtype]
         if indices.dtype != torch.int32:
             warnings.warn(
                 f"The data type of the input `indices` of Permute is {indices.dtype}! "
@@ -78,13 +75,12 @@ class _moe_permute(torch.autograd.Function):
         if fp8:
             permuted_act = Float8Tensor(
-                data=permuted_act, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
+                data=permuted_act, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
             )

         ctx.row_id_map = row_id_map
         ctx.num_tokens = indices.size(0)
         ctx.topK = indices.size(1)
-        ctx.dtype = dtype
         ctx.fp8 = fp8
         return permuted_act, row_id_map
@@ -101,30 +97,27 @@ class _moe_permute(torch.autograd.Function):
         if not permuted_act_grad.is_contiguous():
             permuted_act_grad = permuted_act_grad.contiguous()

-        fp8 = ctx.fp8
-        if fp8:
+        if ctx.fp8:
             assert isinstance(
                 permuted_act_grad, Float8Tensor
             ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute."
-            fp8_dtype = permuted_act_grad._fp8_dtype
+            dtype = permuted_act_grad._fp8_dtype
             fp8_scale_inv = permuted_act_grad._scale_inv
             permuted_act_grad = permuted_act_grad._data
+        else:
+            dtype = TE_DType[permuted_act_grad.dtype]

-        row_id_map = ctx.row_id_map
-        num_tokens = ctx.num_tokens
-        topK = ctx.topK
         act_grad = None
         if ctx.needs_input_grad[0]:
             act_grad = tex.moe_permute_bwd(
-                permuted_act_grad, ctx.dtype, row_id_map, torch.empty(0), num_tokens, topK
+                permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
             )
-        if fp8:
+        if ctx.fp8:
             act_grad = Float8Tensor(
-                data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv * topK
+                data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv * ctx.topK
             )

-        return act_grad, None, None, None, None
+        return act_grad, None, None, None
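
The shorter return tuple follows from the autograd contract: `torch.autograd.Function.backward` must return one gradient per `forward` argument (after `ctx`), so removing `dtype` drops one `None` slot, leaving a real gradient for `inp` and `None` for `indices`, `num_out_tokens`, and `max_token_num`. The same shift explains why `_moe_unpermute.backward` below now tests `ctx.needs_input_grad[2]` instead of `[3]`: with `dtype` gone, `probs` moves from position 3 to position 2 of `forward`'s inputs. Schematically (signatures copied from this diff, comments added for illustration):

    # _moe_permute.forward(ctx, inp, indices, num_out_tokens, max_token_num)
    # -> backward returns (grad_inp, None, None, None)
    # _moe_unpermute.forward(ctx, inp, row_id_map, probs)
    # -> backward returns (grad_inp, None, grad_probs); probs is input index 2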

 class _moe_unpermute(torch.autograd.Function):
@@ -134,7 +127,6 @@ class _moe_unpermute(torch.autograd.Function):
     def forward(
         ctx,
         inp: torch.Tensor,
-        dtype: tex.DType,
         row_id_map: torch.Tensor,
         probs: torch.Tensor,
     ) -> torch.Tensor:
@@ -166,16 +158,13 @@ class _moe_unpermute(torch.autograd.Function):
         assert row_id_map.is_cuda, "TransformerEngine needs CUDA."

         # Data type check
-        fp8 = False
-        if dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
-            fp8 = True
+        fp8 = isinstance(inp, Float8Tensor)
         if fp8:
-            assert isinstance(
-                inp, Float8Tensor
-            ), "Input must be in Float8Tensor type for FP8 moe_unpermute."
-            fp8_dtype = inp._fp8_dtype
+            dtype = inp._fp8_dtype
             fp8_scale_inv = inp._scale_inv
             inp = inp._data
+        else:
+            dtype = TE_DType[inp.dtype]
         if row_id_map.dtype != torch.int32:
             warnings.warn(
                 f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
@@ -187,10 +176,9 @@ class _moe_unpermute(torch.autograd.Function):
         if fp8:
             unpermuted_output = Float8Tensor(
-                data=unpermuted_output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
+                data=unpermuted_output, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv
             )

-        ctx.dtype = dtype
         ctx.save_for_backward(inp, row_id_map, probs)
         ctx.fp8 = fp8
         return unpermuted_output
@@ -207,35 +195,33 @@ class _moe_unpermute(torch.autograd.Function):
         if not unpermuted_act_grad.is_contiguous():
             unpermuted_act_grad = unpermuted_act_grad.contiguous()

-        fp8 = ctx.fp8
-        if fp8:
+        if ctx.fp8:
             assert isinstance(
                 unpermuted_act_grad, Float8Tensor
             ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute."
-            fp8_dtype = unpermuted_act_grad._fp8_dtype
+            dtype = unpermuted_act_grad._fp8_dtype
             fp8_scale_inv = unpermuted_act_grad._scale_inv
             unpermuted_act_grad = unpermuted_act_grad._data
+        else:
+            dtype = TE_DType[unpermuted_act_grad.dtype]

         inp, row_id_map, probs = ctx.saved_tensors
         act_grad = None
         if ctx.needs_input_grad[0]:
             act_grad, prob_grad = tex.moe_unpermute_bwd(
-                unpermuted_act_grad, inp, ctx.dtype, row_id_map, probs
+                unpermuted_act_grad, inp, dtype, row_id_map, probs
             )
-        if fp8:
-            act_grad = Float8Tensor(
-                data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv
-            )
+        if ctx.fp8:
+            act_grad = Float8Tensor(data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv)

-        if not ctx.needs_input_grad[3]:
+        if not ctx.needs_input_grad[2]:
             prob_grad = None

-        return act_grad, None, None, prob_grad
+        return act_grad, None, prob_grad


 def moe_permute(
     inp: torch.Tensor,
-    dtype: tex.DType,
     indices: torch.Tensor,
     num_out_tokens: int = -1,
     max_token_num: int = -1,
@@ -247,8 +233,6 @@ def moe_permute(
     ----------
     inp: torch.Tensor
         Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
-    dtype: tex.DType
-        Data type of the input tensor.
     indices: torch.Tensor
         The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'.
     num_out_tokens: int, default = -1
@@ -259,12 +243,11 @@ def moe_permute(
         By default, set to '-1', meaning the calculation of the size of workspace is
         automatically taken over by the operator.
     """
-    return _moe_permute.apply(inp, dtype, indices, num_out_tokens, max_token_num)
+    return _moe_permute.apply(inp, indices, num_out_tokens, max_token_num)
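
A short usage sketch of the updated `moe_permute` signature; the shapes and values below are made up for illustration and assume a CUDA device:

    import torch
    from transformer_engine.pytorch import moe_permute

    num_tokens, hidden_size, topK, num_experts = 8, 16, 2, 4
    inp = torch.randn(
        num_tokens, hidden_size, dtype=torch.float16, device="cuda", requires_grad=True
    )
    # Token-to-expert routing choices, one row per token, int32 as the op expects
    indices = torch.randint(
        0, num_experts, (num_tokens, topK), dtype=torch.int32, device="cuda"
    )

    # No dtype argument anymore: it is inferred from `inp`
    permuted, row_id_map = moe_permute(inp, indices, num_out_tokens=num_tokens * topK)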
 def moe_unpermute(
     inp: torch.Tensor,
-    dtype: tex.DType,
     row_id_map: torch.Tensor,
     probs: torch.Tensor = None,
 ) -> torch.Tensor:
@@ -276,8 +259,6 @@ def moe_unpermute(
     ----------
     inp: torch.Tensor
         Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
-    dtype: tex.DType
-        Data type of the input tensor.
     row_id_map: torch.Tensor
         The tensor of a mapping table for sorted indices used to unpermute the tokens,
         which is the second output tensor of `Permute`.
@@ -286,4 +267,4 @@ def moe_unpermute(
         the unpermuted tokens will be merged with their respective probabilities.
         By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
     """
-    return _moe_unpermute.apply(inp, dtype, row_id_map, probs)
+    return _moe_unpermute.apply(inp, row_id_map, probs)
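
And the matching `moe_unpermute` call, continuing the hypothetical sketch above: the permuted tokens are merged back into their original order, weighted by routing probabilities, and gradients flow to both `inp` and `probs`:

    from transformer_engine.pytorch import moe_unpermute

    # Per-token routing probabilities over the topK choices (illustrative values)
    probs = torch.softmax(
        torch.randn(num_tokens, topK, device="cuda"), dim=-1
    ).requires_grad_(True)
    merged = moe_unpermute(permuted, row_id_map, probs)  # [num_tokens, hidden_size]
    merged.sum().backward()  # populates inp.grad and probs.grad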