Unverified Commit 8d0187f1 authored by Przemyslaw Tredak, committed by GitHub

Use internal quantizer in Linear module (#1638)



* Changes to Linear
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Removing unnecessary check
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Relax the absolute tolerance in FP32 distributed test
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add QuantizedTensorBase class
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Change the blockwise tensor
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* A little cleaning
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent fe31af80
@@ -173,7 +173,7 @@ def _get_tolerances(dtype):
     if dtype == torch.bfloat16:
         return {"rtol": 1.6e-2, "atol": 1e-5}
     if dtype == torch.float32:
-        return {"rtol": 1.3e-6, "atol": 1e-5}
+        return {"rtol": 1.3e-6, "atol": 4e-5}
     raise ValueError(f"Unsupported dtype ({dtype})")
...
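To make the effect of the relaxed tolerance concrete, here is a hedged, self-contained sketch of how a tolerance table like this is typically consumed; the helper body is copied from the hunk above, while the tensors and the assert_close call are illustrative, not part of this PR.

import torch

def _get_tolerances(dtype):
    # Copied from the hunk above, with the relaxed FP32 absolute tolerance.
    if dtype == torch.bfloat16:
        return {"rtol": 1.6e-2, "atol": 1e-5}
    if dtype == torch.float32:
        return {"rtol": 1.3e-6, "atol": 4e-5}
    raise ValueError(f"Unsupported dtype ({dtype})")

# Distributed and single-GPU runs can differ slightly in reduction order,
# so an FP32 mismatch of a few 1e-5 is expected noise rather than a bug.
out = torch.randn(128, dtype=torch.float32)
ref = out + 2e-5  # deviation within the relaxed 4e-5 absolute tolerance
torch.testing.assert_close(out, ref, **_get_tolerances(torch.float32))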
@@ -9,6 +9,8 @@ from typing import Any, Dict, Optional

 import torch

+from .tensor.quantized_tensor import QuantizedTensorBase
 from .tensor.float8_tensor import Float8Tensor

 __all__ = ["get_cpu_offload_context"]
@@ -342,7 +344,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
             ),
         )

-        is_quantized_tensor = callable(getattr(tensor, "prepare_for_saving", None))
+        is_quantized_tensor = isinstance(tensor, QuantizedTensorBase)

        if not torch_stray_tensor:
...
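The hunk above replaces a duck-typed check with an isinstance test against the new base class. A minimal sketch of the behavioral difference, assuming the import path used elsewhere in this diff:

from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase

def is_quantized_tensor_old(tensor) -> bool:
    # Old check: matches any object that happens to expose a callable
    # prepare_for_saving attribute, including unrelated user types.
    return callable(getattr(tensor, "prepare_for_saving", None))

def is_quantized_tensor_new(tensor) -> bool:
    # New check: matches exactly the *TensorBase hierarchy introduced in this
    # PR (Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase, ...).
    return isinstance(tensor, QuantizedTensorBase)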
@@ -15,6 +15,7 @@ import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import torch_version
 from .base import (
     get_workspace,
     get_ub,
@@ -58,6 +59,7 @@ from ..jit import no_torch_dynamo
 from ..graph import is_graph_capturing
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
+    QuantizedTensorBase,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
@@ -164,7 +166,7 @@ class _Linear(torch.autograd.Function):
                 inputmat, tp_group, quantizer=input_quantizer
             )
         else:
-            if not isinstance(inputmat, QuantizedTensor):
+            if not isinstance(inputmat, QuantizedTensorBase):
                 columnwise_usage = backward_needs_input and isinstance(
                     input_quantizer, MXFP8Quantizer
                 )
@@ -191,7 +193,7 @@ class _Linear(torch.autograd.Function):
                 rowwise=True,
                 columnwise=backward_needs_input,
             )
-            if not isinstance(inputmat, QuantizedTensor):
+            if not isinstance(inputmat, QuantizedTensorBase):
                 inputmat = input_quantizer(inputmat)
                 own_quantized_input = True
         elif backward_needs_input:
@@ -257,7 +259,7 @@ class _Linear(torch.autograd.Function):
             ub_obj = get_ub(ub_name + "_fprop")
             ub_type = tex.CommOverlapType.RS
             out_shape = [reduce(multiply_op, inp.shape[:-1]) // tp_world_size, out_features]
-            rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device)
+            rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)
         elif ub_overlap_ag_fprop:
             ub_obj = get_ub(ub_name + "_fprop")
@@ -297,19 +299,19 @@ class _Linear(torch.autograd.Function):
             )

         if backward_needs_input:
-            if own_quantized_input and isinstance(inputmat, QuantizedTensor):
+            if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
                 # For sequence parallel in vanilla FP8, rowwise data is needed
                 # to gather the input. For MXFP8, columnwise only data
                 # can be allgathered.
                 if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
                     inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
             if force_hp_input_gather:
-                assert not isinstance(inputmat, QuantizedTensor)
+                assert not isinstance(inputmat, QuantizedTensorBase)
             saved_inputmat = inputmat

         # Weight with column-wise usage is needed for dgrad GEMM.
         if inp.requires_grad:
-            if isinstance(weightmat, QuantizedTensor):
+            if isinstance(weightmat, QuantizedTensorBase):
                 weightmat.update_usage(columnwise_usage=True)

         if cpu_offloading and saved_inputmat is not None:
@@ -322,7 +324,7 @@ class _Linear(torch.autograd.Function):
             ctx.fsdp_shapes = _fsdp_scatter_tensors(
                 fsdp_group,
                 saved_inputmat,
-                weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None,
+                weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None,
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
@@ -589,7 +591,7 @@ class _Linear(torch.autograd.Function):
                 recipe.fp8_gemm_dgrad.use_split_accumulator
             )

-        if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor):
+        if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
             weight_fp8.update_usage(
                 rowwise_usage=ctx.weight_quantizer.rowwise_usage,
                 columnwise_usage=ctx.weight_quantizer.columnwise_usage,
@@ -649,7 +651,7 @@ class _Linear(torch.autograd.Function):
                 inputmat_total_work.wait()
                 inputmat_total_work = None
                 if ctx.input_quantizer is not None and not isinstance(
-                    inputmat_total, QuantizedTensor
+                    inputmat_total, QuantizedTensorBase
                 ):
                     # Async gather in BF16 does not asynchronously
                     # call quantizer after gather.
@@ -657,9 +659,9 @@ class _Linear(torch.autograd.Function):
                     inputmat_total = ctx.input_quantizer(inputmat_total)

             # Make sure GEMM inputs have required data
-            if isinstance(inputmat_total, QuantizedTensor):
+            if isinstance(inputmat_total, QuantizedTensorBase):
                 inputmat_total.update_usage(columnwise_usage=True)
-            if isinstance(grad_output, QuantizedTensor):
+            if isinstance(grad_output, QuantizedTensorBase):
                 grad_output.update_usage(columnwise_usage=True)

             # Figure out whether to use split accumulator
@@ -760,7 +762,7 @@ class _Linear(torch.autograd.Function):
             nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")

         # Scatter fp8 weight buffers
-        if ctx.fp8 and not isinstance(weight, QuantizedTensor):
+        if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
             _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)

         return (
             wgrad,
@@ -1308,7 +1310,7 @@ class Linear(TransformerEngineBaseModule):
             grad_output_quantizer = None
             output_quantizer = None
             input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
-            input_quantizer.internal = False
+            input_quantizer.internal = True
             weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
             weight_quantizer.internal = True
             if fp8_output:
...
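Flipping input_quantizer.internal to True lets the forward pass keep the quantized input as a lightweight tensor base instead of a full QuantizedTensor. From the caller's side nothing changes; a minimal usage sketch (shapes and recipe are illustrative, and an FP8-capable GPU is required):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

linear = te.Linear(1024, 1024, bias=True).cuda()
inp = torch.randn(32, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    out = linear(inp)  # the input is quantized via an internal-only tensor base
out.sum().backward()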
@@ -12,12 +12,14 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

+from ..quantized_tensor import QuantizedTensorBase
 from ...constants import TE_DType_To_Torch
 from ..quantized_tensor import Quantizer

-class Float8BlockwiseQTensorBase:
+class Float8BlockwiseQTensorBase(QuantizedTensorBase):
     """Mixin class that holds data attributes of Float8BlockwiseQTensor.

     Float8BlockwiseQTensor inherits from the PyTorch tensor class and this
@@ -293,3 +295,55 @@ class Float8BlockwiseQTensorBase:
             f"fp8_dtype={self._fp8_dtype}, "
             f"{descriptor}_scaled_data={data}"
         )
+
+    def update_usage(
+        self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None
+    ):
+        """
+        update_usage can be used to clear out one of two possible copies of the data.
+        """
+        if rowwise_usage is None:
+            rowwise_usage = self._rowwise_data is not None
+        if columnwise_usage is None:
+            columnwise_usage = self._columnwise_data is not None
+        assert (
+            columnwise_usage or rowwise_usage
+        ), "Must retain some data either columnwise or rowwise"
+        if columnwise_usage and rowwise_usage:
+            if not self._is_2D_scaled:
+                # For 1D scaling, we cannot create columnwise data/scale_inv from rowwise
+                # data/scale_inv because their scale values are different.
+                assert (
+                    self._rowwise_data is not None
+                    and self._rowwise_scale_inv is not None
+                    and self._columnwise_data is not None
+                    and self._columnwise_scale_inv is not None
+                ), "Cannot update to rowwise and columnwise usage."
+            else:
+                # For 2D scaling, if columnwise data/scale_inv is None, we can create them from
+                # rowwise data/scale_inv.
+                assert (
+                    self._rowwise_data is not None and self._rowwise_scale_inv is not None
+                ), "Cannot update to rowwise and columnwise usage because rowwise data is None."
+                if self._columnwise_data is None or self._columnwise_scale_inv is None:
+                    self._create_columnwise()
+            return
+        if rowwise_usage:
+            assert (
+                self._rowwise_data is not None and self._rowwise_scale_inv is not None
+            ), "Cannot update to rowwise usage."
+            self._columnwise_data = None
+            self._columnwise_scale_inv = None
+            return
+        if columnwise_usage:
+            assert (
+                self._columnwise_data is not None and self._columnwise_scale_inv is not None
+            ), "Cannot update to columnwise usage."
+            self._rowwise_data = None
+            self._rowwise_scale_inv = None
+            return
+        return
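A hedged sketch of the call pattern this method enables; keep_only_columnwise is a hypothetical helper, and it peeks at private members purely for illustration:

def keep_only_columnwise(qtensor):
    # After the forward GEMM only the columnwise copy is needed (e.g. for the
    # wgrad GEMM), so the rowwise copy can be freed to save memory.
    qtensor.update_usage(rowwise_usage=False, columnwise_usage=True)
    assert qtensor._rowwise_data is None
    assert qtensor._columnwise_data is not None

For 2D-scaled tensors, update_usage(rowwise_usage=True, columnwise_usage=True) can additionally rebuild a missing columnwise copy from the rowwise one; 1D scaling cannot, because the two layouts use different scale values.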
@@ -12,10 +12,14 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

+from ..quantized_tensor import QuantizedTensorBase
 from ...constants import TE_DType as torch_to_transformer_engine_dtype
 from ..quantized_tensor import Quantizer
+from ...utils import is_non_tn_fp8_gemm_supported

 class _FromFloat8Func(torch.autograd.Function):
     """Cast from FP8 to other dtype"""
@@ -48,7 +52,7 @@ class _FromFloat8Func(torch.autograd.Function):
         return grad, None

-class Float8TensorBase:
+class Float8TensorBase(QuantizedTensorBase):
     """Mixin class that holds data attributes of Float8Tensor.

     Float8Tensor inherits from the PyTorch tensor class and this mixin
@@ -107,7 +111,7 @@ class Float8TensorBase:
             "quantizer": self._quantizer,
         }

-    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
+    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
         """Prepare the tensor base for saving for backward"""
         tensors = [self._data, self._transpose, self._scale_inv]
         self._data = None
@@ -155,3 +159,43 @@ class Float8TensorBase:
         data = data.contiguous()
         self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
         self._transpose_invalid = False
+
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
+        """
+        Generate or remove FP8 data based on provided usage. For
+        FP8, data cannot be generated even if transpose is available.
+        """
+        has_data = self._data is not None
+        has_data_transpose = self._transpose is not None and not self._transpose_invalid
+        needs_data = has_data
+        needs_data_transpose = has_data_transpose
+        if is_non_tn_fp8_gemm_supported():
+            if rowwise_usage is not None and rowwise_usage:
+                needs_data = True
+            if columnwise_usage is not None and columnwise_usage:
+                needs_data = True
+                needs_data_transpose = False
+        else:
+            if rowwise_usage is not None:
+                needs_data = rowwise_usage
+            if columnwise_usage is not None:
+                needs_data_transpose = columnwise_usage
+
+        # Generate data that is required
+        if needs_data and not has_data:
+            raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
+        if needs_data_transpose and not has_data_transpose:
+            if not has_data:
+                raise RuntimeError("FP8 data is required to generate FP8 data transpose")
+            self._create_transpose()
+
+        # Delete data that is not required
+        if not needs_data:
+            self._data = None
+        if not needs_data_transpose:
+            self._transpose = None
+            self._transpose_invalid = True
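A hedged sketch of how this plays out on GPUs where is_non_tn_fp8_gemm_supported() returns False, i.e. where columnwise usage is realized as an explicit FP8 transpose; ensure_columnwise_only is a hypothetical helper and t is assumed to be a Float8TensorBase holding valid _data:

def ensure_columnwise_only(t):
    t.update_usage(columnwise_usage=True)  # builds _transpose from _data if missing
    t.update_usage(rowwise_usage=False)    # the rowwise _data copy can then be freed

The asymmetry in the docstring matters here: the transpose can always be regenerated from the data, but once the data is dropped it cannot be recovered from the transpose.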
@@ -11,6 +11,8 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

+from ..quantized_tensor import QuantizedTensorBase
 from ...constants import TE_DType as torch_to_transformer_engine_dtype
 from ..quantized_tensor import Quantizer
@@ -43,7 +45,7 @@ class _FromMXFP8Func(torch.autograd.Function):
         return grad, None

-class MXFP8TensorBase:
+class MXFP8TensorBase(QuantizedTensorBase):
     """Mixin class that holds data attributes of MXFP8Tensor.

     MXFP8Tensor inherits from the PyTorch tensor class and this mixin
@@ -151,3 +153,48 @@ class MXFP8TensorBase:
             f"rowwise_scale_inv={self._rowwise_scale_inv}, "
             ")"
         )
+
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
+        """
+        For MXFP8, columnwise scaled output is only produced by x2
+        scaling kernels, so this function only disables usages.
+        """
+        # Default usage is based on available data
+        if rowwise_usage is None:
+            rowwise_usage = self._rowwise_data is not None
+        if columnwise_usage is None:
+            columnwise_usage = self._columnwise_data is not None
+
+        # Update row-scaled data
+        if rowwise_usage:
+            if self._rowwise_data is None:
+                raise RuntimeError(
+                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data"
+                )
+            if self._rowwise_scale_inv is None:
+                raise RuntimeError(
+                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses"
+                )
+        else:
+            self._rowwise_data = None
+            self._rowwise_scale_inv = None
+
+        # Update column-scaled data
+        if columnwise_usage:
+            if self._columnwise_data is None:
+                raise RuntimeError(
+                    "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data"
+                )
+            if self._columnwise_scale_inv is None:
+                raise RuntimeError(
+                    "Requested column-wise usage, "
+                    "but MXFP8Tensor is missing column-scaled scale-inverses"
+                )
+        else:
+            self._columnwise_data = None
+            self._columnwise_scale_inv = None
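In contrast to Float8TensorBase above, this method is strictly one-way; a hedged sketch (drop_columnwise is a hypothetical helper, and t is assumed to be an MXFP8TensorBase holding both copies):

def drop_columnwise(t):
    t.update_usage(columnwise_usage=False)  # frees columnwise data/scale_inv
    try:
        t.update_usage(columnwise_usage=True)
    except RuntimeError:
        # Expected: only the x2 scaling kernels produce columnwise MXFP8 data,
        # so update_usage cannot rebuild it once dropped.
        pass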
@@ -309,58 +309,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
         # pylint: disable=missing-function-docstring
         return Float8BlockwiseQTensor.make_like(self)

-    def update_usage(
-        self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None
-    ):
-        """
-        update_usage can be used to clear out one of two possible copies of the data.
-        """
-        if rowwise_usage is None:
-            rowwise_usage = self._rowwise_data is not None
-        if columnwise_usage is None:
-            columnwise_usage = self._columnwise_data is not None
-        assert (
-            columnwise_usage or rowwise_usage
-        ), "Must retain some data either columnwise or rowwise"
-        if columnwise_usage and rowwise_usage:
-            if not self._is_2D_scaled:
-                # For 1D scaling, we cannot create columnwise data/scale_inv from rowwise
-                # data/scale_inv because their scale values are different.
-                assert (
-                    self._rowwise_data is not None
-                    and self._rowwise_scale_inv is not None
-                    and self._columnwise_data is not None
-                    and self._columnwise_scale_inv is not None
-                ), "Cannot update to rowwise and columnwise usage."
-            else:
-                # For 2D scaling, if columnwise data/scale_inv is None, we can create them from
-                # rowwise data/scale_inv.
-                assert (
-                    self._rowwise_data is not None and self._rowwise_scale_inv is not None
-                ), "Cannot update to rowwise and columnwise usage because rowwise data is None."
-                if self._columnwise_data is None or self._columnwise_scale_inv is None:
-                    self._create_columnwise()
-            return
-        if rowwise_usage:
-            assert (
-                self._rowwise_data is not None and self._rowwise_scale_inv is not None
-            ), "Cannot update to rowwise usage."
-            self._columnwise_data = None
-            self._columnwise_scale_inv = None
-            return
-        if columnwise_usage:
-            assert (
-                self._columnwise_data is not None and self._columnwise_scale_inv is not None
-            ), "Cannot update to columnwise usage."
-            self._rowwise_data = None
-            self._rowwise_scale_inv = None
-            return
-        return
-
     def clone(self) -> Float8BlockwiseQTensor:
         # pylint: disable=missing-function-docstring
         rowwise_data = None
...
@@ -11,7 +11,7 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

-from ..utils import canonicalize_process_group, devices_match, is_non_tn_fp8_gemm_supported
+from ..utils import canonicalize_process_group, devices_match
 from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
 from ..constants import dist_group_type
@@ -422,43 +422,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         # pylint: disable=missing-function-docstring
         return Float8Tensor.make_like(self)

-    def update_usage(
-        self,
-        rowwise_usage: Optional[bool] = None,
-        columnwise_usage: Optional[bool] = None,
-    ):
-        # Figure out what data is available and what is required
-        has_data = self._data is not None
-        has_data_transpose = self._transpose is not None and not self._transpose_invalid
-        needs_data = has_data
-        needs_data_transpose = has_data_transpose
-        if is_non_tn_fp8_gemm_supported():
-            if rowwise_usage is not None and rowwise_usage:
-                needs_data = True
-            if columnwise_usage is not None and columnwise_usage:
-                needs_data = True
-                needs_data_transpose = False
-        else:
-            if rowwise_usage is not None:
-                needs_data = rowwise_usage
-            if columnwise_usage is not None:
-                needs_data_transpose = columnwise_usage
-
-        # Generate data that is required
-        if needs_data and not has_data:
-            raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
-        if needs_data_transpose and not has_data_transpose:
-            if not has_data:
-                raise RuntimeError("FP8 data is required to generate FP8 data transpose")
-            self._create_transpose()
-
-        # Delete data that is not required
-        if not needs_data:
-            self._data = None
-        if not needs_data_transpose:
-            self._transpose = None
-            self._transpose_invalid = True
-
     def clone(self) -> Float8Tensor:
         # pylint: disable=missing-function-docstring
         assert self._data is not None
...
@@ -217,51 +217,6 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
         # TODO(ksivamani): Fix the detach bug
         return MXFP8Tensor.make_like(self)

-    def update_usage(
-        self,
-        rowwise_usage: Optional[bool] = None,
-        columnwise_usage: Optional[bool] = None,
-    ):
-        """
-        For MXFP8, columnwise scaled output is only produced by x2
-        scaling kernels, so this function only disables usages.
-        """
-        # Default usage is based on available data
-        if rowwise_usage is None:
-            rowwise_usage = self._rowwise_data is not None
-        if columnwise_usage is None:
-            columnwise_usage = self._columnwise_data is not None
-
-        # Update row-scaled data
-        if rowwise_usage:
-            if self._rowwise_data is None:
-                raise RuntimeError(
-                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data"
-                )
-            if self._rowwise_scale_inv is None:
-                raise RuntimeError(
-                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses"
-                )
-        else:
-            self._rowwise_data = None
-            self._rowwise_scale_inv = None
-
-        # Update column-scaled data
-        if columnwise_usage:
-            if self._columnwise_data is None:
-                raise RuntimeError(
-                    "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data"
-                )
-            if self._columnwise_scale_inv is None:
-                raise RuntimeError(
-                    "Requested column-wise usage, "
-                    "but MXFP8Tensor is missing column-scaled scale-inverses"
-                )
-        else:
-            self._columnwise_data = None
-            self._columnwise_scale_inv = None
-
     def clone(self) -> MXFP8Tensor:
         # pylint: disable=missing-function-docstring
         assert self._rowwise_data is not None
...
@@ -15,9 +15,66 @@ from torch.utils._pytree import tree_map

 import transformer_engine_torch as tex

+
+class QuantizedTensorBase:
+    r"""Base class for all *TensorBase classes.
+
+    This class (and its subclasses) is an optimization for the case when
+    the full QuantizedTensor is not needed (when it is fully contained
+    inside a torch.autograd function and not visible to PyTorch's autograd).
+
+    When creating a new tensor type X, one should create both an
+    XTensorBase class inheriting from QuantizedTensorBase and an
+    XTensor inheriting from XTensorBase and QuantizedTensor.
+    XTensorBase should contain all data members needed to implement
+    the functionality of the tensor, while XTensor should only
+    implement the functionality needed to behave like a regular
+    torch.Tensor (like __torch_dispatch__)."""
+
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
+        r"""
+        Generate or remove quantized data based on provided usage.
+
+        Parameters
+        ----------
+        rowwise_usage : Optional[bool], default = `None`
+            Whether to create or keep the data needed for using the tensor
+            in rowwise fashion (e.g. as the B argument in a TN GEMM). Leaving it
+            as `None` preserves the original value in the tensor.
+        columnwise_usage : Optional[bool], default = `None`
+            Whether to create or keep the data needed for using the tensor
+            in columnwise fashion (e.g. as the A argument in a TN GEMM). Leaving
+            it as `None` preserves the original value in the tensor.
+        """
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement update_usage function"
+        )
+
+    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
+        """Prepare the tensor base for saving for backward"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement prepare_for_saving function"
+        )
+
+    def restore_from_saved(
+        self, tensors: list[Optional[torch.Tensor]]
+    ) -> list[Optional[torch.Tensor]]:
+        """Restore the tensor base data from the saved tensors list"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement restore_from_saved function"
+        )
 def prepare_for_saving(
-    *tensors,
-) -> Tuple[list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], Optional[Any]]:
+    *tensors: Union[torch.Tensor, QuantizedTensorBase],
+) -> Tuple[
+    list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorBase]]
+]:
     """Prepare tensors for saving. Needed because save_for_backward accepts only
     torch.Tensor/torch.nn.Parameter types, while we want to be able to save
     the internal TensorBase types too."""
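The docstring of the new QuantizedTensorBase class above prescribes a two-class pattern for adding tensor types. A hedged sketch of that contract; MyFP4TensorBase and MyFP4Tensor are hypothetical names, and the two saved fields are arbitrary:

class MyFP4TensorBase(QuantizedTensorBase):
    """Holds the raw data members; cheap to pass around inside autograd functions."""

    def __init__(self, data, scale_inv):
        self._data = data
        self._scale_inv = scale_inv

    def prepare_for_saving(self):
        # Hand the raw tensors to save_for_backward and empty this shell.
        tensors = [self._data, self._scale_inv]
        self._data, self._scale_inv = None, None
        return tensors, self

    def restore_from_saved(self, tensors):
        # Reclaim this object's tensors and return the unconsumed remainder.
        self._data, self._scale_inv = tensors[0], tensors[1]
        return tensors[2:]

class MyFP4Tensor(MyFP4TensorBase, QuantizedTensor):
    """Adds the torch.Tensor behavior (__torch_dispatch__ and friends) on top."""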
@@ -35,10 +92,13 @@ def prepare_for_saving(

 def restore_from_saved(
-    tensors: list[Optional[Any]],
+    tensors: list[Optional[Union[torch.Tensor, QuantizedTensorBase]]],
     saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
     return_saved_tensors: bool = False,
-) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]:
+) -> (
+    list[Optional[torch.Tensor | QuantizedTensorBase]]
+    | tuple[list[Optional[torch.Tensor | QuantizedTensorBase]], list[Optional[torch.Tensor]]]
+):
     """Recombine the tensor data and metadata during backward pass."""
     tensor_objects = []
     for tensor in tensors:
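These two helpers exist so autograd functions can save mixed lists of plain tensors and tensor-base objects. A minimal sketch of the intended round trip, mirroring how _Linear in this diff uses them (the autograd function itself is illustrative, not from this PR):

import torch

class _Example(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, quantized_weight):
        # Split each argument into raw torch.Tensors (saved via autograd)
        # and lightweight shell objects holding the remaining metadata.
        tensors_to_save, tensor_objects = prepare_for_saving(inp, quantized_weight)
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        return inp.clone()

    @staticmethod
    def backward(ctx, grad):
        # Reattach the saved raw tensors to their shell objects.
        inp, quantized_weight = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
        return grad, None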
@@ -294,21 +354,6 @@ class QuantizedTensor(torch.Tensor):
             f"{self.__class__.__name__} class does not implement detach function"
         )

-    def update_usage(
-        self,
-        rowwise_usage: Optional[bool] = None,
-        columnwise_usage: Optional[bool] = None,
-    ):
-        """Indicate to the tensor how it is going to be used
-
-        This enables optimizations to memory usage in some cases
-        where forward and backward passes use the tensor in
-        different directions.
-        """
-        raise NotImplementedError(
-            f"{self.__class__.__name__} class does not implement update_usage function"
-        )
-
     def clear(self):
         """Deallocate this tensor's memory. Typically not needed and must be used carefully"""
...