Unverified Commit 21b780cc authored by Jan Bielak, committed by GitHub

Enable use of internal tensors in Sequential (#1900)



* Replace `is_float8_tensor` with `is_quantized_tensor`

Replace the free function `is_float8_tensor` with `is_quantized_tensor` in `_common.py` and use it throughout the `ops` codebase to check whether a tensor is a (possibly internal) quantized tensor; a sketch of the distinction follows below.
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
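
The distinction being checked, as a hedged illustration (the class relationships are inferred from the imports and the new helper in this diff, not restated from the library source): `Float8Tensor` is the full `torch.Tensor` subclass, while the internal `Float8TensorBase` only derives from `QuantizedTensorBase`, so the old `isinstance` check missed it.

```python
# Sketch only (not part of the diff): why the old check misses internal tensors.
# Assumes Float8TensorBase derives from QuantizedTensorBase but not from
# Float8Tensor, which the imports and the new helper below suggest.
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase


def is_float8_tensor(tensor) -> bool:
    """Old helper: only matches the full torch.Tensor subclass."""
    return isinstance(tensor, Float8Tensor)


def is_quantized_tensor(tensor) -> bool:
    """New helper: also matches internal tensors such as Float8TensorBase."""
    return isinstance(tensor, QuantizedTensorBase)
```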

* Pass next and previous op quantizers directly to op_forward and fuser_forward

Change the interfaces of `fuser_forward` and `op_forward` so they no longer take the preceding and following ops; instead they directly take the following op's input quantizer and the preceding op's grad-input quantizer (see the signature sketch below).
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
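
A minimal sketch of the updated `op_forward` signature, mirroring the hunks below; the op class and its body are placeholders for illustration, not a complete operation.

```python
from typing import Optional

import torch

from transformer_engine.pytorch.ops.op import BasicOperation, OperationContext
from transformer_engine.pytorch.tensor import Quantizer


class _ExampleOp(BasicOperation):
    """Hypothetical op, shown only to illustrate the new argument list."""

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_input_quantizer: Optional[Quantizer],  # quantizer for the grad handed back to the previous op
        next_op_input_quantizer: Optional[Quantizer],       # quantizer the next op expects for its forward input
        is_first_op: bool,                                   # replaces the old `prev_op is not None` checks
    ) -> torch.Tensor:
        ctx.is_first_op = is_first_op
        return input_
```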

* Remove redundant `detach` in `BasicLinear`

Remove the use of `detach` in `BasicLinear` for improved performance (enabled by no longer passing `prev_op` to the backward pass).
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Handle saving internal tensors

Handle saving internal tensors in `_OperationFuserAutogradFunction` using `prepare_for_saving` and `restore_from_saved`; a hedged sketch of the pattern follows below.
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
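
A hedged sketch of that saving pattern. The import path and exact return values of `prepare_for_saving` / `restore_from_saved` are assumptions based on how Transformer Engine's module code handles quantized tensors (split them into plain `torch.Tensor` data for `ctx.save_for_backward` plus lightweight tensor objects, then reattach the data in backward); the autograd function here is a toy stand-in, not `_OperationFuserAutogradFunction` itself.

```python
# Toy stand-in for the fuser autograd function; helper signatures are assumed.
import torch
from transformer_engine.pytorch.tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)


class _ToyFuserFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, weight):
        # Internal tensors cannot be passed to save_for_backward directly;
        # prepare_for_saving is assumed to return (plain tensors, tensor objects).
        tensors_to_save, tensor_objects = prepare_for_saving(input_, weight)
        ctx.save_for_backward(*tensors_to_save)
        ctx.tensor_objects = tensor_objects
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # restore_from_saved is assumed to reattach the saved data to the objects.
        input_, weight = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
        del input_, weight  # a real backward would use these for the grad computation
        return grad_output, None
```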

* Use internal tensors

Enable use of internal tensors in the `BasicLinear` quantizers and fix issues resulting from internal tensors not having all the methods that regular tensors have (see the sketch below).
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
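
The mechanism, as far as this diff shows it, is the `internal` flag set on the quantizers in the `BasicLinear` hunk below. A hedged sketch of the effect; the `Float8Quantizer` constructor arguments are illustrative placeholders, and the claim that the flag makes the quantizer return an internal tensor is the assumption this change relies on.

```python
# Illustrative only; constructor arguments, shapes, and device are placeholders.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

quantizer = Float8Quantizer(
    scale=torch.ones(1, device="cuda"),
    amax=torch.zeros(1, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)

quantizer.internal = True  # ask the quantizer to produce an internal tensor
x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
x_fp8 = quantizer(x)       # presumably a Float8TensorBase rather than a full Float8Tensor
```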

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 447de6da
......@@ -28,7 +28,6 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
......
......@@ -21,7 +21,6 @@ import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
......@@ -32,6 +31,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
......@@ -370,7 +370,7 @@ def _test_linear(
if quantized_compute:
tols = dtype_tols(
model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight)
if isinstance(model[0].weight, Float8Tensor)
else tex.DType.kFloat8E4M3
)
......
......@@ -19,7 +19,6 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import (
BackwardLinearAdd,
ForwardLinearBiasActivation,
......
......@@ -5,7 +5,7 @@
"""Helper functions used in fusible operations."""
from __future__ import annotations
from typing import Any, Optional
from typing import Optional
import torch
......@@ -17,16 +17,16 @@ from ..tensor.quantized_tensor import QuantizedTensorBase
from ..utils import canonicalize_dtype
def is_float8_tensor(tensor: Any) -> bool:
"""Check if object is a `Float8Tensor`"""
return isinstance(tensor, Float8Tensor)
def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool:
"""Check if tensor is a quantized tensor"""
return isinstance(tensor, QuantizedTensorBase)
def maybe_dequantize(
tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None
) -> torch.Tensor:
"""Dequantize tensor to given dtype or just convert if not a quantized tensor"""
if isinstance(tensor, QuantizedTensorBase):
if is_quantized_tensor(tensor):
return tensor.dequantize(dtype=dtype)
if dtype is not None and tensor.dtype != dtype:
return tensor.to(dtype)
......
......@@ -12,7 +12,7 @@ import torch
import transformer_engine_torch as tex
from ...fp8 import FP8GlobalStateManager
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize
......@@ -71,8 +71,9 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
# Compute dtype
......@@ -88,14 +89,10 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
x = maybe_dequantize(input_.contiguous(), dtype)
# Check if quantized compute is enabled
quantized_compute_enabled = FP8GlobalStateManager.is_fp8_enabled()
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
quantizer = None
if (
quantized_compute_enabled
and next_op is not None
and next_op.num_quantizers("forward") > 0
):
quantizer = next_op.get_quantizer("forward", 0)
if with_quantized_compute:
quantizer = next_op_input_quantizer
# Launch kernel
y = self._activation_forward_impl(
......@@ -104,7 +101,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
)
# Check output tensor
if y.dim() != x.dim():
if len(y.size()) != x.dim():
y = y.view(list(x.shape[:-1]) + [-1])
# Quantize input to FP8 before caching if needed
......@@ -114,10 +111,11 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
x = input_quantizer(x)
# Save state for backward pass
ctx.save_for_backward(x.detach())
ctx.quantized_compute_enabled = quantized_compute_enabled
ctx.save_for_backward(x)
ctx.with_quantized_compute = with_quantized_compute
ctx.dtype = dtype
ctx.prev_op = prev_op
ctx.is_first_op = is_first_op
ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer
return y
......@@ -138,12 +136,8 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Check if quantized compute is enabled
quantizer = None
if (
ctx.quantized_compute_enabled
and ctx.prev_op is not None
and ctx.prev_op.num_quantizers("backward") > 0
):
quantizer = ctx.prev_op.get_quantizer("backward", 0)
if ctx.with_quantized_compute:
quantizer = ctx.prev_op_grad_input_quantizer
# Launch kernel
dx = self._activation_backward_impl(
......@@ -157,7 +151,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
dx = dx.view(x.size())
# Clear input tensor if possible
if ctx.prev_op is not None:
if not ctx.is_first_op:
clear_tensor_data(x)
return dx, ()
......
......@@ -15,6 +15,8 @@ from transformer_engine.pytorch.ops.op import (
OperationContext,
)
from transformer_engine.pytorch.tensor import Quantizer
class AddInPlace(BasicOperation):
"""Add in-place
......@@ -57,8 +59,9 @@ class AddInPlace(BasicOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
output = basic_op_extra_inputs[0][0].detach()
......
......@@ -12,6 +12,7 @@ import torch
from ...distributed import gather_along_first_dim
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class AllGather(BasicOperation):
......@@ -39,8 +40,9 @@ class AllGather(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
out: torch.Tensor
if self.process_group_size == 1:
......
......@@ -11,6 +11,7 @@ import torch
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class AllReduce(BasicOperation):
......@@ -41,8 +42,9 @@ class AllReduce(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
# Trivial case
......
......@@ -21,13 +21,13 @@ from ...distributed import (
)
from ...fp8 import FP8GlobalStateManager
from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
from ...tensor import Quantizer, QuantizedTensor
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize
from .._common import maybe_dequantize, is_quantized_tensor
from ...utils import (
canonicalize_device,
canonicalize_dtype,
......@@ -272,7 +272,7 @@ class BasicLinear(BasicOperation):
device = canonicalize_device(None)
# Allocate buffer if needed
if isinstance(weight, QuantizedTensor):
if is_quantized_tensor(weight):
weight = torch.empty(
weight.size(),
dtype=weight.dtype,
......@@ -324,6 +324,9 @@ class BasicLinear(BasicOperation):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
input_quantizer.internal = True
weight_quantizer.internal = True
grad_output_quantizer.internal = True
# Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
......@@ -463,7 +466,7 @@ class BasicLinear(BasicOperation):
quantizer=input_quantizer,
)
else:
if not isinstance(x_local, QuantizedTensor):
if not is_quantized_tensor(x_local):
x_local = input_quantizer(x_local)
x = x_local
else:
......@@ -482,7 +485,7 @@ class BasicLinear(BasicOperation):
w = weight
if not with_quantized_compute:
w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not isinstance(w, QuantizedTensor):
elif with_quantized_compute and not is_quantized_tensor(w):
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
......@@ -495,7 +498,7 @@ class BasicLinear(BasicOperation):
output_quantizer = None
if tensor_parallel_mode == "row":
output_quantizer = None
elif isinstance(y, QuantizedTensor):
elif is_quantized_tensor(y):
if not with_quantized_compute:
raise ValueError("Output tensor is quantized, but quantized compute is not enabled")
if tensor_parallel_mode == "row":
......@@ -560,18 +563,14 @@ class BasicLinear(BasicOperation):
# Prepare weight tensor for backward pass
if input_requires_grad:
if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensor):
if w is not weight and with_quantized_compute and is_quantized_tensor(w):
w.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
w = None
# Prepare input tensor for backward pass
if weight_requires_grad:
if x_local is input:
# PyTorch autograd produces esoteric errors if we
# cache input tensor directly.
x_local = x_local.detach()
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
if with_quantized_compute and is_quantized_tensor(x_local):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
......@@ -664,7 +663,7 @@ class BasicLinear(BasicOperation):
# Check datatype
if dtype is None:
if weight is not None:
if weight is not None and not is_quantized_tensor(weight):
dtype = weight.dtype
else:
dtype = grad_output.dtype
......@@ -692,7 +691,7 @@ class BasicLinear(BasicOperation):
quantizer=grad_output_quantizer,
)
else:
if not isinstance(dy_local, QuantizedTensor):
if not is_quantized_tensor(dy_local):
dy_local = grad_output_quantizer(dy_local)
dy = dy_local
else:
......@@ -727,7 +726,7 @@ class BasicLinear(BasicOperation):
quantizer=input_quantizer,
)
else:
if isinstance(x_local, QuantizedTensor):
if is_quantized_tensor(x_local):
x_local.update_usage(columnwise_usage=True)
else:
x_local = input_quantizer(x_local)
......@@ -754,7 +753,7 @@ class BasicLinear(BasicOperation):
raise ValueError("Weight tensor is required to compute input grad")
w = weight
if with_quantized_compute:
if isinstance(w, QuantizedTensor):
if is_quantized_tensor(w):
w.update_usage(columnwise_usage=True)
else:
if weight_quantizer is None:
......@@ -775,7 +774,7 @@ class BasicLinear(BasicOperation):
grad_input_quantizer = None
if tensor_parallel_mode == "column":
grad_input_quantizer = None
elif isinstance(dx, QuantizedTensor):
elif is_quantized_tensor(dx):
if not with_quantized_compute:
raise ValueError(
"Grad input tensor is quantized, but quantized compute is not enabled"
......@@ -886,12 +885,13 @@ class BasicLinear(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
# Check which grads are required
input_requires_grad = ctx.requires_grad and input_.requires_grad
input_requires_grad = ctx.requires_grad
weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
# FP8 metadata
......@@ -906,11 +906,9 @@ class BasicLinear(BasicOperation):
# Get quantizers
input_quantizer = self.get_quantizer("forward", 0)
weight_quantizer = self.get_quantizer("forward", 1)
if next_op is not None and next_op.num_quantizers("forward") > 0:
output_quantizer = next_op.get_quantizer("forward", 0)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = self.get_quantizer("backward", 0)
if prev_op is not None and prev_op.num_quantizers("backward") > 0:
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Configure quantizers
# Note: We cache the quantized input for backward pass,
......@@ -949,7 +947,7 @@ class BasicLinear(BasicOperation):
ctx.dtype = dtype
ctx.input_requires_grad = input_requires_grad
ctx.weight_requires_grad = weight_requires_grad
ctx.has_prev_op = prev_op is not None
ctx.has_prev_op = not is_first_op
return output
......
......@@ -17,6 +17,7 @@ from ...utils import (
canonicalize_device,
canonicalize_dtype,
)
from ...tensor import Quantizer
class Bias(BasicOperation):
......@@ -120,8 +121,9 @@ class Bias(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
x = input_
b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])
......
......@@ -13,6 +13,7 @@ from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...tensor import Quantizer
class Identity(BasicOperation):
......@@ -22,8 +23,9 @@ class Identity(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
return input_
......
......@@ -19,6 +19,7 @@ from ...jit import (
set_jit_fusion_options,
warmup_jit_l2normalization_all_dtypes,
)
from ...tensor import Quantizer
class L2Normalization(BasicOperation):
......@@ -73,8 +74,9 @@ class L2Normalization(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
# Use input directly - torch.compile can handle multi-dimensional tensors
x = maybe_dequantize(input_)
......@@ -95,7 +97,7 @@ class L2Normalization(BasicOperation):
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, rsqrt_norm)
ctx.has_prev_op = prev_op is not None
ctx.has_prev_op = not is_first_op
return y
......
......@@ -23,6 +23,7 @@ from ...utils import (
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...tensor import Quantizer
class LayerNorm(BasicOperation):
......@@ -175,8 +176,9 @@ class LayerNorm(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
# Check tensor dims
......@@ -201,12 +203,9 @@ class LayerNorm(BasicOperation):
# Check if output is quantized
output_quantizer = None
if (
FP8GlobalStateManager.is_fp8_enabled()
and next_op is not None
and next_op.num_quantizers("forward") > 0
):
output_quantizer = next_op.get_quantizer("forward", 0)
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
output_quantizer = next_op_input_quantizer
# Compute layer norm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
......@@ -226,7 +225,7 @@ class LayerNorm(BasicOperation):
if requires_grad:
ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None
ctx.has_prev_op = not is_first_op
# Reshape output tensor
out = y.view(input_dims)
......
......@@ -14,6 +14,7 @@ from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...tensor import Quantizer
class MakeExtraOutput(BasicOperation):
......@@ -58,8 +59,9 @@ class MakeExtraOutput(BasicOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
return input_, [(input_,)]
......
......@@ -10,8 +10,9 @@ from typing import Optional
import torch
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from .._common import is_quantized_tensor
from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class Quantize(BasicOperation):
......@@ -49,8 +50,9 @@ class Quantize(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
# Check if FP8 is enabled
......@@ -60,7 +62,7 @@ class Quantize(BasicOperation):
# Quantize if needed
out = input_
if quantize_forward and not isinstance(out, QuantizedTensor):
if quantize_forward and not is_quantized_tensor(out):
out = self.get_quantizer("forward", 0)(out)
ctx.quantize_backward = quantize_backward
......@@ -72,6 +74,6 @@ class Quantize(BasicOperation):
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
grad_input = grad_output
if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor):
if ctx.quantize_backward and not is_quantized_tensor(grad_input):
grad_input = self.get_quantizer("backward", 0)(grad_input)
return grad_input, ()
......@@ -12,6 +12,7 @@ import torch
from ...distributed import gather_along_first_dim
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...tensor import Quantizer
class ReduceScatter(BasicOperation):
......@@ -39,8 +40,9 @@ class ReduceScatter(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
# Trivial case
......
......@@ -14,6 +14,7 @@ from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...tensor import Quantizer
class Reshape(BasicOperation):
......@@ -37,8 +38,9 @@ class Reshape(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
ctx.input_shape = input_.size()
return input_.reshape(*self._shape)
......
......@@ -23,6 +23,7 @@ from ...utils import (
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, maybe_dequantize
from ...tensor import Quantizer
class RMSNorm(BasicOperation):
......@@ -158,8 +159,9 @@ class RMSNorm(BasicOperation):
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
) -> torch.Tensor:
# Check tensor dims
......@@ -183,12 +185,9 @@ class RMSNorm(BasicOperation):
# Check if output is quantized
output_quantizer = None
if (
FP8GlobalStateManager.is_fp8_enabled()
and next_op is not None
and next_op.num_quantizers("forward") > 0
):
output_quantizer = next_op.get_quantizer("forward", 0)
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
if with_quantized_compute:
output_quantizer = next_op_input_quantizer
# Compute RMSNorm
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
......@@ -207,7 +206,7 @@ class RMSNorm(BasicOperation):
if requires_grad:
ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None
ctx.has_prev_op = not is_first_op
# Reshape output tensor
out = y.view(input_dims)
......
......@@ -13,11 +13,11 @@ import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusedOperation,
FusibleOperation,
OperationContext,
)
from ...tensor import Quantizer
class ForwardLinearBiasActivation(FusedOperation):
......@@ -59,8 +59,9 @@ class ForwardLinearBiasActivation(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -83,7 +84,7 @@ class ForwardLinearBiasActivation(FusedOperation):
raise NotImplementedError("Activations are not yet supported")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
......@@ -96,13 +97,9 @@ class ForwardLinearBiasActivation(FusedOperation):
if with_quantized_compute:
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
next_op = basic_op_next_ops[-1]
if next_op is not None and next_op.num_quantizers("forward") > 0:
output_quantizer = next_op.get_quantizer("forward", 0)
output_quantizer = next_op_input_quantizer
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_quantizers("backward") > 0:
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed
dtype = None
......@@ -136,7 +133,7 @@ class ForwardLinearBiasActivation(FusedOperation):
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
linear_op_ctx.has_prev_op = not is_first_op
return output, [() for _ in range(len(self.basic_ops))]
......
......@@ -13,11 +13,11 @@ import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias
from transformer_engine.pytorch.ops.op import (
BasicOperation,
FusedOperation,
FusibleOperation,
OperationContext,
)
from transformer_engine.pytorch.tensor import Quantizer
class ForwardLinearBiasAdd(FusedOperation):
......@@ -57,8 +57,9 @@ class ForwardLinearBiasAdd(FusedOperation):
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
prev_op_grad_input_quantizer: Optional[Quantizer],
next_op_input_quantizer: Optional[Quantizer],
is_first_op: bool,
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
......@@ -77,7 +78,7 @@ class ForwardLinearBiasAdd(FusedOperation):
raise ValueError("Bias operation forward does not expect keyword arguments")
# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
input_requires_grad = linear_op_ctx.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad
# FP8 metadata
......@@ -91,9 +92,7 @@ class ForwardLinearBiasAdd(FusedOperation):
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_quantizers("backward") > 0:
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
grad_input_quantizer = prev_op_grad_input_quantizer
# Get autocast dtype if needed
dtype = None
......@@ -129,7 +128,7 @@ class ForwardLinearBiasAdd(FusedOperation):
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None
linear_op_ctx.has_prev_op = not is_first_op
return output, [() for _ in range(len(self.basic_ops))]
......