Unverified Commit 9bf4175f authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Deprecate old `float8_tensor.py` (#2250)



Deprecate old float8_tensor.py
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e37e33e1
...@@ -29,7 +29,7 @@ from transformer_engine.pytorch.tensor.quantized_tensor import ( ...@@ -29,7 +29,7 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.constants import ( from transformer_engine.pytorch.constants import (
TE_DType, TE_DType,
QKVLayouts, QKVLayouts,
......
...@@ -20,7 +20,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import ( ...@@ -20,7 +20,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
FusedAttnBackend, FusedAttnBackend,
) )
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage
from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.jit import jit_fuser
from transformer_engine.pytorch.constants import ( from transformer_engine.pytorch.constants import (
......
...@@ -30,7 +30,7 @@ from transformer_engine.pytorch.fp8 import ( ...@@ -30,7 +30,7 @@ from transformer_engine.pytorch.fp8 import (
Float8CurrentScalingRecipeState, Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState, Float8BlockScalingRecipeState,
) )
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.constants import ( from transformer_engine.pytorch.constants import (
......
...@@ -35,8 +35,8 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import ( ...@@ -35,8 +35,8 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_DP, META_DP,
) )
from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import ( from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
Float8Quantizer, Float8Quantizer,
Float8CurrentScalingQuantizer, Float8CurrentScalingQuantizer,
) )
......
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm
from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
......
...@@ -4,6 +4,16 @@ ...@@ -4,6 +4,16 @@
"""Tensor class with FP8 data""" """Tensor class with FP8 data"""
import warnings
from .tensor.float8_tensor import Float8Tensor from .tensor.float8_tensor import Float8Tensor
warnings.warn(
"transformer_engine.pytorch.float8_tensor is deprecated and will be removed"
" in a future release. Float8Tensor should be imported directly through "
"`from transformer_engine.pytorch import Float8Tensor`",
DeprecationWarning,
stacklevel=2,
)
__all__ = ["Float8Tensor"] __all__ = ["Float8Tensor"]
...@@ -184,7 +184,7 @@ def combine_tensors( ...@@ -184,7 +184,7 @@ def combine_tensors(
num_tensors = len(tensors) num_tensors = len(tensors)
new_shape = list(tensors[0].shape) new_shape = list(tensors[0].shape)
new_shape.insert(dim, num_tensors) new_shape.insert(dim, num_tensors)
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
if isinstance(tensors[0], Float8Tensor): if isinstance(tensors[0], Float8Tensor):
new_stride = list(tensors[0]._data.stride()) new_stride = list(tensors[0]._data.stride())
...@@ -224,7 +224,7 @@ class SplitAlongDim(torch.autograd.Function): ...@@ -224,7 +224,7 @@ class SplitAlongDim(torch.autograd.Function):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
ctx.split_dim = split_dim ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections ctx.split_size_or_sections = split_size_or_sections
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import ( from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import (
Float8TensorStorage, Float8TensorStorage,
) )
...@@ -278,7 +278,7 @@ class SplitAlongDim(torch.autograd.Function): ...@@ -278,7 +278,7 @@ class SplitAlongDim(torch.autograd.Function):
split_sizes = [ctx.split_size_or_sections] * len(grad_outputs) split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
dims = len(grad_outputs[0].shape) dims = len(grad_outputs[0].shape)
split_dim = (ctx.split_dim + dims) % dims split_dim = (ctx.split_dim + dims) % dims
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
if isinstance(grad_outputs[0], Float8Tensor): if isinstance(grad_outputs[0], Float8Tensor):
noop_ok = True noop_ok = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment