Unverified Commit 21b780cc authored by Jan Bielak's avatar Jan Bielak Committed by GitHub
Browse files

Enable use of internal tensors in Sequential (#1900)



* Replace `is_float8_tensor` with `is_quantized_tensor`

Replace free function `is_float8_tensor` with `is_quantized_tensor` in `_common.py` and use it throughout the `ops` codebase to check if a tensor is a (possibly internal) quantized tensor
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Pass next and previous op quantizers directly to op_forward and fuser_forward

Change interface of `fuser_forward` and `op_forward` to no longer take preceding and following ops and instead take the following op's input quantizer and preceding op's input gradient's quantizer directly
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Remove redundant use of `detach` in `BasicLinear`

Remove use of `detach` in `BasicLinear` for improved performance (enabled by not passing prev_op to backward)
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Handle saving internal tensors

Handle saving internal tensors in `_OperationFuserAutogradFunction` using `prepare_for_saving` and `restore_from_saved`
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Use internal tensors

Enable use of internal tensors in `BasicLinear` quantizers and fix issues resulting from internal tensors not having methods that regular tensors have
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* Apply suggestions from code review
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 447de6da
...@@ -28,7 +28,6 @@ from transformer_engine.pytorch.tensor.float8_tensor import ( ...@@ -28,7 +28,6 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
) )
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex import transformer_engine_torch as tex
......
...@@ -21,7 +21,6 @@ import transformer_engine.pytorch as te ...@@ -21,7 +21,6 @@ import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import ( from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear, UserbuffersBackwardLinear,
UserbuffersForwardLinear, UserbuffersForwardLinear,
...@@ -32,6 +31,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import ( ...@@ -32,6 +31,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
) )
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions # Import utility functions
...@@ -370,7 +370,7 @@ def _test_linear( ...@@ -370,7 +370,7 @@ def _test_linear(
if quantized_compute: if quantized_compute:
tols = dtype_tols( tols = dtype_tols(
model[0].weight._fp8_dtype model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight) if isinstance(model[0].weight, Float8Tensor)
else tex.DType.kFloat8E4M3 else tex.DType.kFloat8E4M3
) )
......
...@@ -19,7 +19,6 @@ import transformer_engine.common.recipe ...@@ -19,7 +19,6 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import ( from transformer_engine.pytorch.ops.fused import (
BackwardLinearAdd, BackwardLinearAdd,
ForwardLinearBiasActivation, ForwardLinearBiasActivation,
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""Helper functions used in fusible operations.""" """Helper functions used in fusible operations."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Optional from typing import Optional
import torch import torch
...@@ -17,16 +17,16 @@ from ..tensor.quantized_tensor import QuantizedTensorBase ...@@ -17,16 +17,16 @@ from ..tensor.quantized_tensor import QuantizedTensorBase
from ..utils import canonicalize_dtype from ..utils import canonicalize_dtype
def is_float8_tensor(tensor: Any) -> bool: def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool:
"""Check if object is a `Float8Tensor`""" """Check if tensor is a quantized tensor"""
return isinstance(tensor, Float8Tensor) return isinstance(tensor, QuantizedTensorBase)
def maybe_dequantize( def maybe_dequantize(
tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None
) -> torch.Tensor: ) -> torch.Tensor:
"""Dequantize tensor to given dtype or just convert if not a quantized tensor""" """Dequantize tensor to given dtype or just convert if not a quantized tensor"""
if isinstance(tensor, QuantizedTensorBase): if is_quantized_tensor(tensor):
return tensor.dequantize(dtype=dtype) return tensor.dequantize(dtype=dtype)
if dtype is not None and tensor.dtype != dtype: if dtype is not None and tensor.dtype != dtype:
return tensor.to(dtype) return tensor.to(dtype)
......
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from ...fp8 import FP8GlobalStateManager from ...fp8 import FP8GlobalStateManager
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize from .._common import maybe_dequantize
...@@ -71,8 +71,9 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -71,8 +71,9 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# Compute dtype # Compute dtype
...@@ -88,14 +89,10 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -88,14 +89,10 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
x = maybe_dequantize(input_.contiguous(), dtype) x = maybe_dequantize(input_.contiguous(), dtype)
# Check if quantized compute is enabled # Check if quantized compute is enabled
quantized_compute_enabled = FP8GlobalStateManager.is_fp8_enabled() with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
quantizer = None quantizer = None
if ( if with_quantized_compute:
quantized_compute_enabled quantizer = next_op_input_quantizer
and next_op is not None
and next_op.num_quantizers("forward") > 0
):
quantizer = next_op.get_quantizer("forward", 0)
# Launch kernel # Launch kernel
y = self._activation_forward_impl( y = self._activation_forward_impl(
...@@ -104,7 +101,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -104,7 +101,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
) )
# Check output tensor # Check output tensor
if y.dim() != x.dim(): if len(y.size()) != x.dim():
y = y.view(list(x.shape[:-1]) + [-1]) y = y.view(list(x.shape[:-1]) + [-1])
# Quantize input to FP8 before caching if needed # Quantize input to FP8 before caching if needed
...@@ -114,10 +111,11 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -114,10 +111,11 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
x = input_quantizer(x) x = input_quantizer(x)
# Save state for backward pass # Save state for backward pass
ctx.save_for_backward(x.detach()) ctx.save_for_backward(x)
ctx.quantized_compute_enabled = quantized_compute_enabled ctx.with_quantized_compute = with_quantized_compute
ctx.dtype = dtype ctx.dtype = dtype
ctx.prev_op = prev_op ctx.is_first_op = is_first_op
ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer
return y return y
...@@ -138,12 +136,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -138,12 +136,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Check if quantized compute is enabled # Check if quantized compute is enabled
quantizer = None quantizer = None
if ( if ctx.with_quantized_compute:
ctx.quantized_compute_enabled quantizer = ctx.prev_op_grad_input_quantizer
and ctx.prev_op is not None
and ctx.prev_op.num_quantizers("backward") > 0
):
quantizer = ctx.prev_op.get_quantizer("backward", 0)
# Launch kernel # Launch kernel
dx = self._activation_backward_impl( dx = self._activation_backward_impl(
...@@ -157,7 +151,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): ...@@ -157,7 +151,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
dx = dx.view(x.size()) dx = dx.view(x.size())
# Clear input tensor if possible # Clear input tensor if possible
if ctx.prev_op is not None: if not ctx.is_first_op:
clear_tensor_data(x) clear_tensor_data(x)
return dx, () return dx, ()
......
...@@ -15,6 +15,8 @@ from transformer_engine.pytorch.ops.op import ( ...@@ -15,6 +15,8 @@ from transformer_engine.pytorch.ops.op import (
OperationContext, OperationContext,
) )
from transformer_engine.pytorch.tensor import Quantizer
class AddInPlace(BasicOperation): class AddInPlace(BasicOperation):
"""Add in-place """Add in-place
...@@ -57,8 +59,9 @@ class AddInPlace(BasicOperation): ...@@ -57,8 +59,9 @@ class AddInPlace(BasicOperation):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]], prev_op_grad_input_quantizer: Optional[Quantizer],
basic_op_next_ops: list[Optional[BasicOperation]], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
output = basic_op_extra_inputs[0][0].detach() output = basic_op_extra_inputs[0][0].detach()
......
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
from ...distributed import gather_along_first_dim from ...distributed import gather_along_first_dim
from .._common import maybe_dequantize from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class AllGather(BasicOperation): class AllGather(BasicOperation):
...@@ -39,8 +40,9 @@ class AllGather(BasicOperation): ...@@ -39,8 +40,9 @@ class AllGather(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
out: torch.Tensor out: torch.Tensor
if self.process_group_size == 1: if self.process_group_size == 1:
......
...@@ -11,6 +11,7 @@ import torch ...@@ -11,6 +11,7 @@ import torch
from .._common import maybe_dequantize from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class AllReduce(BasicOperation): class AllReduce(BasicOperation):
...@@ -41,8 +42,9 @@ class AllReduce(BasicOperation): ...@@ -41,8 +42,9 @@ class AllReduce(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# Trivial case # Trivial case
......
...@@ -21,13 +21,13 @@ from ...distributed import ( ...@@ -21,13 +21,13 @@ from ...distributed import (
) )
from ...fp8 import FP8GlobalStateManager from ...fp8 import FP8GlobalStateManager
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize from .._common import maybe_dequantize, is_quantized_tensor
from ...utils import ( from ...utils import (
canonicalize_device, canonicalize_device,
canonicalize_dtype, canonicalize_dtype,
...@@ -272,7 +272,7 @@ class BasicLinear(BasicOperation): ...@@ -272,7 +272,7 @@ class BasicLinear(BasicOperation):
device = canonicalize_device(None) device = canonicalize_device(None)
# Allocate buffer if needed # Allocate buffer if needed
if isinstance(weight, QuantizedTensor): if is_quantized_tensor(weight):
weight = torch.empty( weight = torch.empty(
weight.size(), weight.size(),
dtype=weight.dtype, dtype=weight.dtype,
...@@ -324,6 +324,9 @@ class BasicLinear(BasicOperation): ...@@ -324,6 +324,9 @@ class BasicLinear(BasicOperation):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
input_quantizer.internal = True
weight_quantizer.internal = True
grad_output_quantizer.internal = True
# Recipe-specific configuration # Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe() recipe = FP8GlobalStateManager.get_fp8_recipe()
...@@ -463,7 +466,7 @@ class BasicLinear(BasicOperation): ...@@ -463,7 +466,7 @@ class BasicLinear(BasicOperation):
quantizer=input_quantizer, quantizer=input_quantizer,
) )
else: else:
if not isinstance(x_local, QuantizedTensor): if not is_quantized_tensor(x_local):
x_local = input_quantizer(x_local) x_local = input_quantizer(x_local)
x = x_local x = x_local
else: else:
...@@ -482,7 +485,7 @@ class BasicLinear(BasicOperation): ...@@ -482,7 +485,7 @@ class BasicLinear(BasicOperation):
w = weight w = weight
if not with_quantized_compute: if not with_quantized_compute:
w = maybe_dequantize(w, dtype) w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not isinstance(w, QuantizedTensor): elif with_quantized_compute and not is_quantized_tensor(w):
if weight_quantizer is None: if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor") raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
...@@ -495,7 +498,7 @@ class BasicLinear(BasicOperation): ...@@ -495,7 +498,7 @@ class BasicLinear(BasicOperation):
output_quantizer = None output_quantizer = None
if tensor_parallel_mode == "row": if tensor_parallel_mode == "row":
output_quantizer = None output_quantizer = None
elif isinstance(y, QuantizedTensor): elif is_quantized_tensor(y):
if not with_quantized_compute: if not with_quantized_compute:
raise ValueError("Output tensor is quantized, but quantized compute is not enabled") raise ValueError("Output tensor is quantized, but quantized compute is not enabled")
if tensor_parallel_mode == "row": if tensor_parallel_mode == "row":
...@@ -560,18 +563,14 @@ class BasicLinear(BasicOperation): ...@@ -560,18 +563,14 @@ class BasicLinear(BasicOperation):
# Prepare weight tensor for backward pass # Prepare weight tensor for backward pass
if input_requires_grad: if input_requires_grad:
if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensor): if w is not weight and with_quantized_compute and is_quantized_tensor(w):
w.update_usage(rowwise_usage=False, columnwise_usage=True) w.update_usage(rowwise_usage=False, columnwise_usage=True)
else: else:
w = None w = None
# Prepare input tensor for backward pass # Prepare input tensor for backward pass
if weight_requires_grad: if weight_requires_grad:
if x_local is input: if with_quantized_compute and is_quantized_tensor(x_local):
# PyTorch autograd produces esoteric errors if we
# cache input tensor directly.
x_local = x_local.detach()
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather): if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data # FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True) x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
...@@ -664,7 +663,7 @@ class BasicLinear(BasicOperation): ...@@ -664,7 +663,7 @@ class BasicLinear(BasicOperation):
# Check datatype # Check datatype
if dtype is None: if dtype is None:
if weight is not None: if weight is not None and not is_quantized_tensor(weight):
dtype = weight.dtype dtype = weight.dtype
else: else:
dtype = grad_output.dtype dtype = grad_output.dtype
...@@ -692,7 +691,7 @@ class BasicLinear(BasicOperation): ...@@ -692,7 +691,7 @@ class BasicLinear(BasicOperation):
quantizer=grad_output_quantizer, quantizer=grad_output_quantizer,
) )
else: else:
if not isinstance(dy_local, QuantizedTensor): if not is_quantized_tensor(dy_local):
dy_local = grad_output_quantizer(dy_local) dy_local = grad_output_quantizer(dy_local)
dy = dy_local dy = dy_local
else: else:
...@@ -727,7 +726,7 @@ class BasicLinear(BasicOperation): ...@@ -727,7 +726,7 @@ class BasicLinear(BasicOperation):
quantizer=input_quantizer, quantizer=input_quantizer,
) )
else: else:
if isinstance(x_local, QuantizedTensor): if is_quantized_tensor(x_local):
x_local.update_usage(columnwise_usage=True) x_local.update_usage(columnwise_usage=True)
else: else:
x_local = input_quantizer(x_local) x_local = input_quantizer(x_local)
...@@ -754,7 +753,7 @@ class BasicLinear(BasicOperation): ...@@ -754,7 +753,7 @@ class BasicLinear(BasicOperation):
raise ValueError("Weight tensor is required to compute input grad") raise ValueError("Weight tensor is required to compute input grad")
w = weight w = weight
if with_quantized_compute: if with_quantized_compute:
if isinstance(w, QuantizedTensor): if is_quantized_tensor(w):
w.update_usage(columnwise_usage=True) w.update_usage(columnwise_usage=True)
else: else:
if weight_quantizer is None: if weight_quantizer is None:
...@@ -775,7 +774,7 @@ class BasicLinear(BasicOperation): ...@@ -775,7 +774,7 @@ class BasicLinear(BasicOperation):
grad_input_quantizer = None grad_input_quantizer = None
if tensor_parallel_mode == "column": if tensor_parallel_mode == "column":
grad_input_quantizer = None grad_input_quantizer = None
elif isinstance(dx, QuantizedTensor): elif is_quantized_tensor(dx):
if not with_quantized_compute: if not with_quantized_compute:
raise ValueError( raise ValueError(
"Grad input tensor is quantized, but quantized compute is not enabled" "Grad input tensor is quantized, but quantized compute is not enabled"
...@@ -886,12 +885,13 @@ class BasicLinear(BasicOperation): ...@@ -886,12 +885,13 @@ class BasicLinear(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# Check which grads are required # Check which grads are required
input_requires_grad = ctx.requires_grad and input_.requires_grad input_requires_grad = ctx.requires_grad
weight_requires_grad = ctx.requires_grad and self.weight.requires_grad weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
# FP8 metadata # FP8 metadata
...@@ -906,11 +906,9 @@ class BasicLinear(BasicOperation): ...@@ -906,11 +906,9 @@ class BasicLinear(BasicOperation):
# Get quantizers # Get quantizers
input_quantizer = self.get_quantizer("forward", 0) input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1) weight_quantizer = self.get_quantizer("forward", 1)
if next_op is not None and next_op.num_quantizers("forward") > 0: output_quantizer = next_op_input_quantizer
output_quantizer = next_op.get_quantizer("forward", 0)
grad_output_quantizer = self.get_quantizer("backward", 0) grad_output_quantizer = self.get_quantizer("backward", 0)
if prev_op is not None and prev_op.num_quantizers("backward") > 0: grad_input_quantizer = prev_op_grad_input_quantizer
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
# Configure quantizers # Configure quantizers
# Note: We cache the quantized input for backward pass, # Note: We cache the quantized input for backward pass,
...@@ -949,7 +947,7 @@ class BasicLinear(BasicOperation): ...@@ -949,7 +947,7 @@ class BasicLinear(BasicOperation):
ctx.dtype = dtype ctx.dtype = dtype
ctx.input_requires_grad = input_requires_grad ctx.input_requires_grad = input_requires_grad
ctx.weight_requires_grad = weight_requires_grad ctx.weight_requires_grad = weight_requires_grad
ctx.has_prev_op = prev_op is not None ctx.has_prev_op = not is_first_op
return output return output
......
...@@ -17,6 +17,7 @@ from ...utils import ( ...@@ -17,6 +17,7 @@ from ...utils import (
canonicalize_device, canonicalize_device,
canonicalize_dtype, canonicalize_dtype,
) )
from ...tensor import Quantizer
class Bias(BasicOperation): class Bias(BasicOperation):
...@@ -120,8 +121,9 @@ class Bias(BasicOperation): ...@@ -120,8 +121,9 @@ class Bias(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
x = input_ x = input_
b = self.bias.view([1] * (x.dim() - 1) + [self.local_size]) b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])
......
...@@ -13,6 +13,7 @@ from transformer_engine.pytorch.ops.op import ( ...@@ -13,6 +13,7 @@ from transformer_engine.pytorch.ops.op import (
BasicOperation, BasicOperation,
OperationContext, OperationContext,
) )
from ...tensor import Quantizer
class Identity(BasicOperation): class Identity(BasicOperation):
...@@ -22,8 +23,9 @@ class Identity(BasicOperation): ...@@ -22,8 +23,9 @@ class Identity(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
return input_ return input_
......
...@@ -19,6 +19,7 @@ from ...jit import ( ...@@ -19,6 +19,7 @@ from ...jit import (
set_jit_fusion_options, set_jit_fusion_options,
warmup_jit_l2normalization_all_dtypes, warmup_jit_l2normalization_all_dtypes,
) )
from ...tensor import Quantizer
class L2Normalization(BasicOperation): class L2Normalization(BasicOperation):
...@@ -73,8 +74,9 @@ class L2Normalization(BasicOperation): ...@@ -73,8 +74,9 @@ class L2Normalization(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# Use input directly - torch.compile can handle multi-dimensional tensors # Use input directly - torch.compile can handle multi-dimensional tensors
x = maybe_dequantize(input_) x = maybe_dequantize(input_)
...@@ -95,7 +97,7 @@ class L2Normalization(BasicOperation): ...@@ -95,7 +97,7 @@ class L2Normalization(BasicOperation):
# Save state for backward pass # Save state for backward pass
if requires_grad: if requires_grad:
ctx.save_for_backward(x, rsqrt_norm) ctx.save_for_backward(x, rsqrt_norm)
ctx.has_prev_op = prev_op is not None ctx.has_prev_op = not is_first_op
return y return y
......
...@@ -23,6 +23,7 @@ from ...utils import ( ...@@ -23,6 +23,7 @@ from ...utils import (
) )
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize from .._common import maybe_autocast_dtype, maybe_dequantize
from ...tensor import Quantizer
class LayerNorm(BasicOperation): class LayerNorm(BasicOperation):
...@@ -175,8 +176,9 @@ class LayerNorm(BasicOperation): ...@@ -175,8 +176,9 @@ class LayerNorm(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# Check tensor dims # Check tensor dims
...@@ -201,12 +203,9 @@ class LayerNorm(BasicOperation): ...@@ -201,12 +203,9 @@ class LayerNorm(BasicOperation):
# Check if output is quantized # Check if output is quantized
output_quantizer = None output_quantizer = None
if ( with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
FP8GlobalStateManager.is_fp8_enabled() if with_quantized_compute:
and next_op is not None output_quantizer = next_op_input_quantizer
and next_op.num_quantizers("forward") > 0
):
output_quantizer = next_op.get_quantizer("forward", 0)
# Compute layer norm # Compute layer norm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"] sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
...@@ -226,7 +225,7 @@ class LayerNorm(BasicOperation): ...@@ -226,7 +225,7 @@ class LayerNorm(BasicOperation):
if requires_grad: if requires_grad:
ctx.save_for_backward(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None ctx.has_prev_op = not is_first_op
# Reshape output tensor # Reshape output tensor
out = y.view(input_dims) out = y.view(input_dims)
......
...@@ -14,6 +14,7 @@ from transformer_engine.pytorch.ops.op import ( ...@@ -14,6 +14,7 @@ from transformer_engine.pytorch.ops.op import (
BasicOperation, BasicOperation,
OperationContext, OperationContext,
) )
from ...tensor import Quantizer
class MakeExtraOutput(BasicOperation): class MakeExtraOutput(BasicOperation):
...@@ -58,8 +59,9 @@ class MakeExtraOutput(BasicOperation): ...@@ -58,8 +59,9 @@ class MakeExtraOutput(BasicOperation):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]], prev_op_grad_input_quantizer: Optional[Quantizer],
basic_op_next_ops: list[Optional[BasicOperation]], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
return input_, [(input_,)] return input_, [(input_,)]
......
...@@ -10,8 +10,9 @@ from typing import Optional ...@@ -10,8 +10,9 @@ from typing import Optional
import torch import torch
from ...fp8 import FP8GlobalStateManager from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor from .._common import is_quantized_tensor
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class Quantize(BasicOperation): class Quantize(BasicOperation):
...@@ -49,8 +50,9 @@ class Quantize(BasicOperation): ...@@ -49,8 +50,9 @@ class Quantize(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# Check if FP8 is enabled # Check if FP8 is enabled
...@@ -60,7 +62,7 @@ class Quantize(BasicOperation): ...@@ -60,7 +62,7 @@ class Quantize(BasicOperation):
# Quantize if needed # Quantize if needed
out = input_ out = input_
if quantize_forward and not isinstance(out, QuantizedTensor): if quantize_forward and not is_quantized_tensor(out):
out = self.get_quantizer("forward", 0)(out) out = self.get_quantizer("forward", 0)(out)
ctx.quantize_backward = quantize_backward ctx.quantize_backward = quantize_backward
...@@ -72,6 +74,6 @@ class Quantize(BasicOperation): ...@@ -72,6 +74,6 @@ class Quantize(BasicOperation):
grad_output: torch.Tensor, grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]: ) -> tuple[torch.Tensor, tuple[()]]:
grad_input = grad_output grad_input = grad_output
if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor): if ctx.quantize_backward and not is_quantized_tensor(grad_input):
grad_input = self.get_quantizer("backward", 0)(grad_input) grad_input = self.get_quantizer("backward", 0)(grad_input)
return grad_input, () return grad_input, ()
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
from ...distributed import gather_along_first_dim from ...distributed import gather_along_first_dim
from .._common import maybe_dequantize from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class ReduceScatter(BasicOperation): class ReduceScatter(BasicOperation):
...@@ -39,8 +40,9 @@ class ReduceScatter(BasicOperation): ...@@ -39,8 +40,9 @@ class ReduceScatter(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# Trivial case # Trivial case
......
...@@ -14,6 +14,7 @@ from transformer_engine.pytorch.ops.op import ( ...@@ -14,6 +14,7 @@ from transformer_engine.pytorch.ops.op import (
BasicOperation, BasicOperation,
OperationContext, OperationContext,
) )
from ...tensor import Quantizer
class Reshape(BasicOperation): class Reshape(BasicOperation):
...@@ -37,8 +38,9 @@ class Reshape(BasicOperation): ...@@ -37,8 +38,9 @@ class Reshape(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
ctx.input_shape = input_.size() ctx.input_shape = input_.size()
return input_.reshape(*self._shape) return input_.reshape(*self._shape)
......
...@@ -23,6 +23,7 @@ from ...utils import ( ...@@ -23,6 +23,7 @@ from ...utils import (
) )
from ..op import BasicOperation, OperationContext from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize from .._common import maybe_autocast_dtype, maybe_dequantize
from ...tensor import Quantizer
class RMSNorm(BasicOperation): class RMSNorm(BasicOperation):
...@@ -158,8 +159,9 @@ class RMSNorm(BasicOperation): ...@@ -158,8 +159,9 @@ class RMSNorm(BasicOperation):
self, self,
ctx: OperationContext, ctx: OperationContext,
input_: torch.Tensor, input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None, prev_op_grad_input_quantizer: Optional[Quantizer],
next_op: Optional[BasicOperation] = None, next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor: ) -> torch.Tensor:
# Check tensor dims # Check tensor dims
...@@ -183,12 +185,9 @@ class RMSNorm(BasicOperation): ...@@ -183,12 +185,9 @@ class RMSNorm(BasicOperation):
# Check if output is quantized # Check if output is quantized
output_quantizer = None output_quantizer = None
if ( with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
FP8GlobalStateManager.is_fp8_enabled() if with_quantized_compute:
and next_op is not None output_quantizer = next_op_input_quantizer
and next_op.num_quantizers("forward") > 0
):
output_quantizer = next_op.get_quantizer("forward", 0)
# Compute RMSNorm # Compute RMSNorm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"] sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
...@@ -207,7 +206,7 @@ class RMSNorm(BasicOperation): ...@@ -207,7 +206,7 @@ class RMSNorm(BasicOperation):
if requires_grad: if requires_grad:
ctx.save_for_backward(x, rstdevs) ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None ctx.has_prev_op = not is_first_op
# Reshape output tensor # Reshape output tensor
out = y.view(input_dims) out = y.view(input_dims)
......
...@@ -13,11 +13,11 @@ import torch ...@@ -13,11 +13,11 @@ import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import BasicLinear, Bias from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
from transformer_engine.pytorch.ops.op import ( from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusedOperation, FusedOperation,
FusibleOperation, FusibleOperation,
OperationContext, OperationContext,
) )
from ...tensor import Quantizer
class ForwardLinearBiasActivation(FusedOperation): class ForwardLinearBiasActivation(FusedOperation):
...@@ -59,8 +59,9 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -59,8 +59,9 @@ class ForwardLinearBiasActivation(FusedOperation):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]], prev_op_grad_input_quantizer: Optional[Quantizer],
basic_op_next_ops: list[Optional[BasicOperation]], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...@@ -83,7 +84,7 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -83,7 +84,7 @@ class ForwardLinearBiasActivation(FusedOperation):
raise NotImplementedError("Activations are not yet supported") raise NotImplementedError("Activations are not yet supported")
# Check which grads are required # Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata # FP8 metadata
...@@ -96,13 +97,9 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -96,13 +97,9 @@ class ForwardLinearBiasActivation(FusedOperation):
if with_quantized_compute: if with_quantized_compute:
input_quantizer = linear_op.get_quantizer("forward", 0) input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1) weight_quantizer = linear_op.get_quantizer("forward", 1)
next_op = basic_op_next_ops[-1] output_quantizer = next_op_input_quantizer
if next_op is not None and next_op.num_quantizers("forward") > 0:
output_quantizer = next_op.get_quantizer("forward", 0)
grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_output_quantizer = linear_op.get_quantizer("backward", 0)
prev_op = basic_op_prev_ops[0] grad_input_quantizer = prev_op_grad_input_quantizer
if prev_op is not None and prev_op.num_quantizers("backward") > 0:
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
# Get autocast dtype if needed # Get autocast dtype if needed
dtype = None dtype = None
...@@ -136,7 +133,7 @@ class ForwardLinearBiasActivation(FusedOperation): ...@@ -136,7 +133,7 @@ class ForwardLinearBiasActivation(FusedOperation):
linear_op_ctx.dtype = dtype linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None linear_op_ctx.has_prev_op = not is_first_op
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
......
...@@ -13,11 +13,11 @@ import torch ...@@ -13,11 +13,11 @@ import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias
from transformer_engine.pytorch.ops.op import ( from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusedOperation, FusedOperation,
FusibleOperation, FusibleOperation,
OperationContext, OperationContext,
) )
from transformer_engine.pytorch.tensor import Quantizer
class ForwardLinearBiasAdd(FusedOperation): class ForwardLinearBiasAdd(FusedOperation):
...@@ -57,8 +57,9 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -57,8 +57,9 @@ class ForwardLinearBiasAdd(FusedOperation):
input_: torch.Tensor, input_: torch.Tensor,
*, *,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]], prev_op_grad_input_quantizer: Optional[Quantizer],
basic_op_next_ops: list[Optional[BasicOperation]], next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
...@@ -77,7 +78,7 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -77,7 +78,7 @@ class ForwardLinearBiasAdd(FusedOperation):
raise ValueError("Bias operation forward does not expect keyword arguments") raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required # Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata # FP8 metadata
...@@ -91,9 +92,7 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -91,9 +92,7 @@ class ForwardLinearBiasAdd(FusedOperation):
input_quantizer = linear_op.get_quantizer("forward", 0) input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1) weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_output_quantizer = linear_op.get_quantizer("backward", 0)
prev_op = basic_op_prev_ops[0] grad_input_quantizer = prev_op_grad_input_quantizer
if prev_op is not None and prev_op.num_quantizers("backward") > 0:
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
# Get autocast dtype if needed # Get autocast dtype if needed
dtype = None dtype = None
...@@ -129,7 +128,7 @@ class ForwardLinearBiasAdd(FusedOperation): ...@@ -129,7 +128,7 @@ class ForwardLinearBiasAdd(FusedOperation):
linear_op_ctx.dtype = dtype linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None linear_op_ctx.has_prev_op = not is_first_op
return output, [() for _ in range(len(self.basic_ops))] return output, [() for _ in range(len(self.basic_ops))]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment