Unverified Commit 9bf4175f authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Deprecate old `float8_tensor.py` (#2250)



Deprecate old float8_tensor.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e37e33e1
......@@ -29,7 +29,7 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.constants import (
TE_DType,
QKVLayouts,
......
......@@ -20,7 +20,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
FusedAttnBackend,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage
from transformer_engine.pytorch.jit import jit_fuser
from transformer_engine.pytorch.constants import (
......
......@@ -30,7 +30,7 @@ from transformer_engine.pytorch.fp8 import (
Float8CurrentScalingRecipeState,
Float8BlockScalingRecipeState,
)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.constants import (
......
......@@ -35,8 +35,8 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_DP,
)
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
......
......@@ -10,7 +10,7 @@ import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm
from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
......
......@@ -4,6 +4,16 @@
"""Tensor class with FP8 data"""
import warnings

from .tensor.float8_tensor import Float8Tensor

# Re-export kept only for backward compatibility; emit a deprecation
# warning once, at import time, pointing callers at the supported path.
_DEPRECATION_MESSAGE = (
    "transformer_engine.pytorch.float8_tensor is deprecated and will be removed"
    " in a future release. Float8Tensor should be imported directly through "
    "`from transformer_engine.pytorch import Float8Tensor`"
)

warnings.warn(_DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)

__all__ = ["Float8Tensor"]
......@@ -184,7 +184,7 @@ def combine_tensors(
num_tensors = len(tensors)
new_shape = list(tensors[0].shape)
new_shape.insert(dim, num_tensors)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
if isinstance(tensors[0], Float8Tensor):
new_stride = list(tensors[0]._data.stride())
......@@ -224,7 +224,7 @@ class SplitAlongDim(torch.autograd.Function):
# pylint: disable=missing-function-docstring
ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import (
Float8TensorStorage,
)
......@@ -278,7 +278,7 @@ class SplitAlongDim(torch.autograd.Function):
split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
dims = len(grad_outputs[0].shape)
split_dim = (ctx.split_dim + dims) % dims
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
if isinstance(grad_outputs[0], Float8Tensor):
noop_ok = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment