Unverified Commit 77c37d49 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Normalization ops (#1033)



* Add layer norm op
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add FP8 cast op
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add tests for linear and layernorm with FP8 output
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* RMSNorm op
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warnings
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Replace LayerNorm module with LayerNorm op
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Replace RMSNorm module with RMSNorm op
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add AMP support
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Do not save autograd context if grad mode is disabled

Debugging ONNX export tests.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Forward args in pre_forward func to base op class
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Update to use QuantizedTensor class
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Apply suggestions from code review
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @ptrendx

Rename "CastFloat8" op to "Quantize". Add more fine-grained control for SM margin. Add docs for legacy sequence_parallel kwarg.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warnings
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Use weight dtype as default compute dtype
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
parent f20d3ddb
This diff is collapsed.
...@@ -82,6 +82,7 @@ from transformer_engine.pytorch.export import onnx_export ...@@ -82,6 +82,7 @@ from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import checkpoint
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch import ops
from transformer_engine.pytorch import optimizers from transformer_engine.pytorch import optimizers
# Register custom op symbolic ONNX functions # Register custom op symbolic ONNX functions
......
...@@ -3,158 +3,90 @@ ...@@ -3,158 +3,90 @@
# See LICENSE for license information. # See LICENSE for license information.
"""LayerNorm API""" """LayerNorm API"""
import os
import warnings import warnings
from typing import Union, Tuple, Optional from typing import Iterable, Optional, Union
import torch import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import transformer_engine_torch as tex from transformer_engine.pytorch.ops import LayerNorm as _LayerNormOp
from ..cpp_extensions import (
layernorm_fwd_inf,
)
from ..jit import no_torch_dynamo
from ..utils import cast_if_needed
__all__ = ["LayerNorm"] __all__ = ["LayerNorm"]
class _LayerNorm(torch.autograd.Function): class LayerNorm(_LayerNormOp):
"""functional LayerNorm""" r"""Layer Normalization
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
eps: float,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
inf_ln_sm_margin: int,
zero_centered_gamma: bool,
is_grad_enabled: bool,
activation_dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.view((-1, in_features))
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if is_grad_enabled:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
else:
ln_out, mu, rsigma = (
layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma
),
None,
None,
)
return ln_out.view_as(inp)
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_ln_out = grad_output.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None
class LayerNorm(torch.nn.Module):
r"""
Applies Layer Normalization over a mini-batch of inputs as described in Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__ the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of :math:`\gamma` and :math:`\beta` are learnable affine transform
size :attr:`hidden_size` parameters that match the inner-most dimensions of the input
tensor.
Parameters Parameters
---------- ----------
hidden_size : int normalized_shape: int or iterable of int
size of each input sample. Inner dimensions of input tensor
eps : float, default = 1e-5 eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability. A value added to the denominator of layer normalization for
sequence_parallel : bool, default = `False` numerical stability
if set to `True`, uses sequence parallelism. device: torch.device, default = default CUDA device
params_dtype : torch.dtype, default = `torch.get_default_dtype()` Tensor device
it controls the type used to allocate the initial parameters. Useful when dtype: torch.dtype, default = default dtype
the model is trained with lower precision and the original FP32 parameters Tensor datatype
would not fit in GPU memory.
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and If `True`, the :math:`\gamma` parameter is initialized to zero
the LayerNorm formula changes to and the calculation changes to
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
(1 + \gamma) + \beta
device : Union[torch.device, str], default = "cuda" sm_margin: int or dict, default = 0
The device on which the parameters of the model will be allocated. It is the user's Number of SMs to exclude when launching CUDA kernels. This
responsibility to ensure all parameters are moved to the GPU before running the helps overlap with other kernels, e.g. communication kernels.
forward pass. For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
"inference").
Legacy
------
sequence_parallel: bool
Set a bool attr named `sequence_parallel` in the parameters.
This is custom logic for Megatron-LM integration.
""" """
def __init__( def __init__(
self, self,
hidden_size: int, normalized_shape: Union[Iterable[int], int],
eps: float = 1e-5, eps: float = 1e-5,
sequence_parallel: bool = False, sequence_parallel: Optional[bool] = None, # legacy
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda", **kwargs,
) -> None: ) -> None:
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype # Handle deprecated options
self.eps = eps if params_dtype is not None:
self.zero_centered_gamma = zero_centered_gamma if "dtype" in kwargs:
self.weight = Parameter( raise RuntimeError(
torch.empty( "Both `dtype` and `params_dtype` (deprecated) kwargs are provided"
hidden_size,
device=device,
dtype=params_dtype,
)
)
self.bias = Parameter(
torch.empty(
hidden_size,
device=device,
dtype=params_dtype,
) )
kwargs["dtype"] = params_dtype
# Initialize layer norm operation
super().__init__(
normalized_shape,
eps=eps,
zero_centered_gamma=zero_centered_gamma,
**kwargs,
) )
self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None
self.reset_parameters(defer_init=device == "meta") # Flag for sequence parallelism (custom Megatron-LM integration)
self.sequence_parallel: Optional[bool] = sequence_parallel
# These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
...@@ -164,64 +96,62 @@ class LayerNorm(torch.nn.Module): ...@@ -164,64 +96,62 @@ class LayerNorm(torch.nn.Module):
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
if not self.zero_centered_gamma: self.reset_parameters()
init.ones_(self.weight)
else:
init.zeros_(self.weight)
init.zeros_(self.bias)
def reset_parameters(self, defer_init=False) -> None: def reset_parameters(self, defer_init: Optional[bool] = None) -> None:
"""Init LayerNorm parameters""" """Init LayerNorm parameters"""
if defer_init:
return
if self.weight.device == torch.device("meta"): # Check whether to defer init (deprecated)
self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda")) if defer_init is not None:
setattr(self.weight, "sequence_parallel", self.sequence_parallel) warnings.warn(
init.constant_(self.weight, float(not self.zero_centered_gamma)) "defer_init argument to reset_parameters function is deprecated. Set device to"
' "meta" instead.',
if self.bias.device == torch.device("meta"): DeprecationWarning,
self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device="cuda")) stacklevel=2,
setattr(self.bias, "sequence_parallel", self.sequence_parallel)
init.zeros_(self.bias)
@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Set the activation type for AMP.
# Note: This will soon be deprecated with
# https://github.com/NVIDIA/TransformerEngine/pull/1033
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
elif self.activation_dtype != inp.dtype:
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype
if torch.is_grad_enabled():
fwd_fn = _LayerNorm.apply
args = []
else:
fwd_fn = _LayerNorm.forward
args = [None]
args += (
inp,
self.weight,
self.bias,
self.eps,
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.inf_ln_sm_margin,
self.zero_centered_gamma,
torch.is_grad_enabled(),
self.activation_dtype,
) )
if defer_init:
return
return fwd_fn(*args) # Reset parameters
super().reset_parameters()
# Set flag for sequence parallelism (custom Megatron-LM integration)
if getattr(self, "sequence_parallel", None) is not None:
self.weight.sequence_parallel = self.sequence_parallel
self.bias.sequence_parallel = self.sequence_parallel
@property
def fwd_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["forward"]
@fwd_ln_sm_margin.setter
def fwd_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["forward"] = val
@property
def bwd_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["backward"]
@bwd_ln_sm_margin.setter
def bwd_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["backward"] = val
@property
def inf_ln_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["inference"]
@inf_ln_sm_margin.setter
def inf_ln_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["inference"] = val
...@@ -3,221 +3,158 @@ ...@@ -3,221 +3,158 @@
# See LICENSE for license information. # See LICENSE for license information.
"""RMSNorm API""" """RMSNorm API"""
import os
import warnings import warnings
from typing import Union, Tuple, Optional from typing import Iterable, Optional, Union
import torch import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from .. import cpp_extensions as tex
from ..jit import no_torch_dynamo
from ..utils import cast_if_needed
from transformer_engine.pytorch.ops import RMSNorm as _RMSNormOp
__all__ = ["RMSNorm"] __all__ = ["RMSNorm"]
class _RMSNorm(torch.autograd.Function): class RMSNorm(_RMSNormOp):
"""functional RMSNorm""" r"""Root Mean Square Layer Normalization
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
rmsnorm_weight: torch.Tensor,
eps: float,
fwd_rmsnorm_sm_margin: int,
bwd_rmsnorm_sm_margin: int,
inf_rmsnorm_sm_margin: int,
zero_centered_gamma: bool,
is_grad_enabled: bool,
activation_dtype: torch.dtype,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
# Make sure input dimensions are compatible
in_features = rmsnorm_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert inp.shape[-1] == in_features, "RMSNorm not possible"
inputmat = inp.view((-1, in_features))
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
rmsnorm_weight = cast_if_needed(rmsnorm_weight, activation_dtype)
if is_grad_enabled:
rmsnorm_out, rsigma = tex.rmsnorm_fwd(
inputmat, rmsnorm_weight, eps, fwd_rmsnorm_sm_margin, zero_centered_gamma
)
ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
else:
rmsnorm_out = tex.rmsnorm_fwd_inf(
inputmat, rmsnorm_weight, eps, inf_rmsnorm_sm_margin, zero_centered_gamma
)
return rmsnorm_out.view_as(inp)
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_rmsnorm_out = grad_output.view(inputmat.shape)
dxmat, dgamma = tex.rmsnorm_bwd(
d_rmsnorm_out,
inputmat,
rsigma,
rmsnorm_weight,
ctx.bwd_rmsnorm_sm_margin,
ctx.zero_centered_gamma,
)
return (
dxmat.view(ctx.inp_shape),
dgamma,
None,
None,
None,
None,
None,
None,
None,
)
class RMSNorm(torch.nn.Module): Applies Root Mean Square Layer Normalization over a mini-batch of
r""" inputs as described in the paper
Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
.. math:: .. math::
y = \frac{x}{RMS_\varepsilon(x)} * \gamma y = \frac{x}{\text{RMS}_\varepsilon(x)} * \gamma
where where
.. math:: .. math::
RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^nx_i^2 + \varepsilon} \text{RMS}_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^n x_i^2 + \varepsilon}
:math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size` :math:`\gamma` is a learnable affine transform parameter that
matches the inner-most dimensions of the input tensor.
Parameters Parameters
---------- ----------
hidden_size : int normalized_shape: int or iterable of int
size of each input sample. Inner dimensions of input tensor
eps : float, default = 1e-5 eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability. A value added to the denominator for numerical stability
sequence_parallel : bool, default = `False` device: torch.device, default = default CUDA device
if set to `True`, uses sequence parallelism. Tensor device
params_dtype : torch.dtype, default = `torch.get_default_dtype()` dtype: torch.dtype, default = default dtype
it controls the type used to allocate the initial parameters. Useful when Tensor datatype
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
zero_centered_gamma : bool, default = 'False' zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in RMSNorm is initialized to 0 and If `True`, the :math:`\gamma` parameter is initialized to zero
the RMSNorm formula changes to and the calculation changes to
.. math:: .. math::
y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma) y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's sm_margin: int, default = 0
responsibility to ensure all parameters are moved to the GPU before running the Number of SMs to exclude when launching CUDA kernels. This
forward pass. helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
"inference").
Legacy
------
sequence_parallel: bool
Set a bool attr named `sequence_parallel` in the parameters.
This is custom logic for Megatron-LM integration.
""" """
def __init__( def __init__(
self, self,
hidden_size: int, normalized_shape: Union[Iterable[int], int],
eps: float = 1e-5, eps: float = 1e-5,
sequence_parallel: bool = False, sequence_parallel: Optional[bool] = None, # legacy
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
device: Union[torch.device, str] = "cuda", **kwargs,
) -> None: ) -> None:
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype # Handle deprecated options
self.eps = eps if params_dtype is not None:
self.zero_centered_gamma = zero_centered_gamma if "dtype" in kwargs:
self.weight = Parameter( raise RuntimeError(
torch.empty( "Both `dtype` and `params_dtype` (deprecated) kwargs are provided"
hidden_size,
device=device,
dtype=params_dtype,
) )
kwargs["dtype"] = params_dtype
# Initialize RMSNorm operation
super().__init__(
normalized_shape,
eps=eps,
zero_centered_gamma=zero_centered_gamma,
**kwargs,
) )
self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None
self.reset_parameters(defer_init=device == "meta") # Flag for sequence parallelism (custom Megatron-LM integration)
self.sequence_parallel: Optional[bool] = sequence_parallel
# These many SMs are subtracted from the total SM count when calling forward
# and backward RMSNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with RMSNorm.
self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_rmsnorm_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
def reset_rms_norm_parameters(self) -> None: def reset_rms_norm_parameters(self) -> None:
"""Init RMSNorm params""" """Deprecated"""
warnings.warn( warnings.warn(
"This method is deprecated and will be removed in an upcoming release. " "This method is deprecated and will be removed in an upcoming release. "
"Update your code to use RMSNorm.reset_parameters() instead.", "Update your code to use RMSNorm.reset_parameters() instead.",
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
if not self.zero_centered_gamma: self.reset_parameters()
init.ones_(self.weight)
else:
init.zeros_(self.weight)
def reset_parameters(self, defer_init=False) -> None: def reset_parameters(self, defer_init: Optional[bool] = None) -> None:
"""Reset RMSNorm parameters""" """Init RMSNorm parameters"""
if defer_init:
return
if self.weight.device == torch.device("meta"): # Check whether to defer init (deprecated)
self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda")) if defer_init is not None:
init.constant_(self.weight, float(not self.zero_centered_gamma)) warnings.warn(
setattr(self.weight, "sequence_parallel", self.sequence_parallel) "defer_init argument to reset_parameters function is deprecated. Set device to"
' "meta" instead.',
@no_torch_dynamo() DeprecationWarning,
def forward(self, inp: torch.Tensor) -> torch.Tensor: stacklevel=2,
# pylint: disable=missing-function-docstring
# Set the activation type for AMP.
# Note: This will soon be deprecated with
# https://github.com/NVIDIA/TransformerEngine/pull/1033
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
elif self.activation_dtype != inp.dtype:
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype
if torch.is_grad_enabled():
fwd_fn = _RMSNorm.apply
args = []
else:
fwd_fn = _RMSNorm.forward
args = [None]
args += (
inp,
self.weight,
self.eps,
self.fwd_rmsnorm_sm_margin,
self.bwd_rmsnorm_sm_margin,
self.inf_rmsnorm_sm_margin,
self.zero_centered_gamma,
torch.is_grad_enabled(),
self.activation_dtype,
) )
if defer_init:
return
return fwd_fn(*args) # Reset parameters
super().reset_parameters()
# Flag for sequence parallelism (custom Megatron-LM integration)
if getattr(self, "sequence_parallel", None) is not None:
self.weight.sequence_parallel = self.sequence_parallel
@property
def fwd_rmsnorm_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("fwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["forward"]
@fwd_rmsnorm_sm_margin.setter
def fwd_rmsnorm_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("fwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["forward"] = val
@property
def bwd_rmsnorm_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("bwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["backward"]
@bwd_rmsnorm_sm_margin.setter
def bwd_rmsnorm_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("bwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["backward"] = val
@property
def inf_rmsnorm_sm_margin(self) -> int:
"""Shim for backward compatibility"""
warnings.warn("inf_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
return self._sm_margins["inference"]
@inf_rmsnorm_sm_margin.setter
def inf_rmsnorm_sm_margin(self, val: int) -> None:
"""Shim for backward compatibility"""
warnings.warn("inf_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2)
self._sm_margins["inference"] = val
...@@ -8,17 +8,7 @@ This operation-based API is experimental and subject to change. ...@@ -8,17 +8,7 @@ This operation-based API is experimental and subject to change.
""" """
from transformer_engine.pytorch.ops.basic import ( from transformer_engine.pytorch.ops.basic import *
AddInPlace,
AllGather,
AllReduce,
BasicLinear,
Bias,
Identity,
MakeExtraOutput,
ReduceScatter,
Reshape,
)
from transformer_engine.pytorch.ops.linear import Linear from transformer_engine.pytorch.ops.linear import Linear
from transformer_engine.pytorch.ops.op import FusibleOperation from transformer_engine.pytorch.ops.op import FusibleOperation
from transformer_engine.pytorch.ops.sequential import Sequential from transformer_engine.pytorch.ops.sequential import Sequential
...@@ -56,6 +56,8 @@ def convert_tensor( ...@@ -56,6 +56,8 @@ def convert_tensor(
if memory_format != torch.preserve_format and not data.is_contiguous( if memory_format != torch.preserve_format and not data.is_contiguous(
memory_format=memory_format memory_format=memory_format
): ):
# Note: torch.Tensor.to ignores memory_format kwarg (see
# https://github.com/pytorch/pytorch/issues/132020).
data = data.contiguous(memory_format=memory_format) data = data.contiguous(memory_format=memory_format)
return Float8Tensor.make_like( return Float8Tensor.make_like(
tensor, tensor,
...@@ -65,7 +67,14 @@ def convert_tensor( ...@@ -65,7 +67,14 @@ def convert_tensor(
) )
# Convert standard PyTorch tensor # Convert standard PyTorch tensor
return tensor.to(device=device, dtype=dtype, memory_format=memory_format) tensor = tensor.to(device=device, dtype=dtype)
if memory_format != torch.preserve_format and not tensor.is_contiguous(
memory_format=memory_format
):
# Note: torch.Tensor.to ignores memory_format kwarg (see
# https://github.com/pytorch/pytorch/issues/132020).
tensor = tensor.contiguous(memory_format=memory_format)
return tensor
def reshape( def reshape(
...@@ -114,3 +123,14 @@ def reshape( ...@@ -114,3 +123,14 @@ def reshape(
# Reshape standard PyTorch tensor # Reshape standard PyTorch tensor
return tensor.view(shape) return tensor.view(shape)
def maybe_autocast_dtype(
*,
device_type: str = "cuda",
default_dtype: Optional[torch.dtype] = None,
) -> torch.dtype:
"""Get autocast dtype if enabled"""
if torch.is_autocast_enabled(device_type):
return torch.get_autocast_dtype(device_type)
return canonicalize_dtype(default_dtype)
...@@ -10,6 +10,9 @@ from .all_reduce import AllReduce ...@@ -10,6 +10,9 @@ from .all_reduce import AllReduce
from .basic_linear import BasicLinear from .basic_linear import BasicLinear
from .bias import Bias from .bias import Bias
from .identity import Identity from .identity import Identity
from .layer_norm import LayerNorm
from .make_extra_output import MakeExtraOutput from .make_extra_output import MakeExtraOutput
from .quantize import Quantize
from .reduce_scatter import ReduceScatter from .reduce_scatter import ReduceScatter
from .reshape import Reshape from .reshape import Reshape
from .rmsnorm import RMSNorm
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusable operation for Layer Normalization."""
from __future__ import annotations
from collections.abc import Iterable
import math
import os
from typing import Optional
import torch
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...cpp_extensions import (
layernorm_fwd_fp8,
layernorm_fwd_fp8_inf,
layernorm_fwd_inf,
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape
class LayerNorm(BasicOperation):
r"""Layer Normalization
Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform
parameters that match the inner-most dimensions of the input
tensor.
Parameters
----------
normalized_shape: int or iterable of int
Inner dimensions of input tensor
eps : float, default = 1e-5
A value added to the denominator of layer normalization for
numerical stability
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
If `True`, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
sm_margin: int or dict, default = 0
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
"inference").
"""
def __init__(
self,
normalized_shape: Iterable[int] | int,
*,
eps: float = 1e-5,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
zero_centered_gamma: bool = False,
sm_margin: int | dict[str, int] = 0,
) -> None:
super().__init__()
self.eps: float = eps
self.zero_centered_gamma: bool = zero_centered_gamma
# Parameter shape
if not isinstance(normalized_shape, Iterable):
normalized_shape = (normalized_shape,)
else:
normalized_shape = tuple(normalized_shape)
self._shape: tuple[int, ...] = normalized_shape
# Parameter device
defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
self.device: torch.device = device
# Initialize parameters if needed
dtype = canonicalize_dtype(dtype)
weight = torch.empty(
self._shape,
device="meta",
dtype=dtype,
)
bias = torch.empty(
self._shape,
device="meta",
dtype=dtype,
)
weight = torch.nn.Parameter(weight)
bias = torch.nn.Parameter(bias)
self.weight: torch.nn.Parameter
self.bias: torch.nn.Parameter
self.register_parameter("weight", weight)
self.register_parameter("bias", bias)
if not defer_param_init:
self.reset_parameters()
# Number of SMs to exclude when launching CUDA kernels
self._sm_margins: dict[str, int]
if isinstance(sm_margin, dict):
def getenv(name: str) -> int:
return int(os.getenv(name, "0"))
self._sm_margins = {
"forward": sm_margin.get("forward", getenv("NVTE_FWD_LAYERNORM_SM_MARGIN")),
"backward": sm_margin.get("backward", getenv("NVTE_BWD_LAYERNORM_SM_MARGIN")),
"inference": sm_margin.get("inference", getenv("NVTE_INF_LAYERNORM_SM_MARGIN")),
}
else:
def getenv(name: str) -> int:
return int(os.getenv(name, str(sm_margin)))
self._sm_margins = {
"forward": getenv("NVTE_FWD_LAYERNORM_SM_MARGIN"),
"backward": getenv("NVTE_BWD_LAYERNORM_SM_MARGIN"),
"inference": getenv("NVTE_INF_LAYERNORM_SM_MARGIN"),
}
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""
# Make sure parameter is initialized
weight = self.weight
bias = self.bias
if weight.device.type != "cuda":
weight = torch.empty_like(weight, device=self.device)
else:
weight = weight.to(device=self.device)
if bias.device.type != "cuda":
bias = torch.empty_like(bias, device=self.device)
else:
bias = bias.to(device=self.device)
# Initialize values
if self.zero_centered_gamma:
torch.nn.init.zeros_(weight)
else:
torch.nn.init.ones_(weight)
torch.nn.init.zeros_(bias)
# Save updated parameter
if not isinstance(weight, torch.nn.Parameter):
weight = torch.nn.Parameter(weight)
if not isinstance(bias, torch.nn.Parameter):
bias = torch.nn.Parameter(bias)
self.weight = weight
self.bias = bias
def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
if self.weight.device.type == "meta" or self.bias.device.type == "meta":
self.reset_parameters()
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Check tensor dims
input_dims = tuple(input_.size())
if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={self._shape}) are not compatible"
)
# Check input tensors
inner_dim = math.prod(self._shape)
device = self.device
dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype)
if isinstance(x, QuantizedTensor):
x = x.dequantize()
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(b, QuantizedTensor):
b = b.dequantize()
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if FP8 is enabled
with_fp8_output = (
FP8GlobalStateManager.is_fp8_enabled()
and next_op is not None
and next_op.num_fp8_scales("input") > 0
)
output_fp8_meta = None
if with_fp8_output:
output_fp8_meta = next_op.get_fp8_meta("input")
# Compute layer norm
y = None
means = None
rstdevs = None
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
if with_fp8_output:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True)
args = (
x,
w,
b,
self.eps,
output_fp8_meta[fp8_meta_key],
0, # fp8_meta_index
fp8_dtype,
sm_margin,
self.zero_centered_gamma,
)
if requires_grad:
data, means, rstdevs = layernorm_fwd_fp8(*args)
else:
data = layernorm_fwd_fp8_inf(*args)
y = Float8Tensor(
data=data,
fp8_meta=output_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
dtype=dtype,
)
else:
args = (
x,
w,
b,
self.eps,
sm_margin,
self.zero_centered_gamma,
)
if requires_grad:
y, means, rstdevs = layernorm_fwd(*args)
else:
y = layernorm_fwd_inf(*args)
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, means, rstdevs)
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None
# Reshape output tensor
out = reshape(y, input_dims)
return out
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
x, means, rstdevs = ctx.saved_tensors
# Check input tensors
inner_dim = x.size(-1)
device = self.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
# Compute layer norm backward pass
dx, dw, db = layernorm_bwd(
dy,
x,
means,
rstdevs,
w,
self._sm_margins["backward"],
self.zero_centered_gamma,
)
# Clear saved tensors if possible
if ctx.has_prev_op:
clear_tensor_data(x)
clear_tensor_data(means)
clear_tensor_data(rstdevs)
# Reshape results
grad_input = reshape(dx, grad_output.size())
grad_weight = reshape(dw, self._shape)
grad_bias = reshape(db, self._shape)
return grad_input, (grad_weight, grad_bias)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for quantization."""
from __future__ import annotations
from typing import Optional
import torch
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ..op import BasicOperation, OperationContext
class Quantize(BasicOperation):
"""Quantize tensor data
Uses FP8 recipe from `fp8_autocast` context. When called outside
of an `fp8_autocast` context, this is an identity operation.
Parameters
----------
forward: bool, default = `True`
Perform quantization in forward pass
backward: bool, default = `False`
Perform quantization in backward pass
"""
def __init__(
self,
forward: bool = True,
backward: bool = False,
) -> None:
super().__init__()
self._quantize_forward = forward
self._quantize_backward = backward
def num_fp8_scales(self, mode: str) -> int:
if mode == "input" and self._quantize_forward:
return 1
if mode == "grad_output" and self._quantize_backward:
return 1
return 0
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Check if FP8 is enabled
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
quantize_forward = fp8_enabled and self._quantize_forward
quantize_backward = fp8_enabled and self._quantize_backward
# Quantize if needed
out = input_
if quantize_forward and not isinstance(out, QuantizedTensor):
fp8_meta = self.get_fp8_meta("input")
fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
out = Float8Tensor.to_float8(
out,
fp8_meta=fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
)
ctx.quantize_backward = quantize_backward
return out
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
grad_input = grad_output
if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor):
fp8_meta = self.get_fp8_meta("grad_output")
fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
grad_input = Float8Tensor.to_float8(
grad_input,
fp8_meta=fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
)
return grad_input, ()
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusable operation for RMSNorm."""
from __future__ import annotations
from collections.abc import Iterable
import math
import os
from typing import Optional
import torch
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...cpp_extensions import (
rmsnorm_fwd_fp8,
rmsnorm_fwd_fp8_inf,
rmsnorm_fwd_inf,
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape
class RMSNorm(BasicOperation):
r"""Root Mean Square Layer Normalization
Applies Root Mean Square Layer Normalization over a mini-batch of
inputs as described in the paper
`Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
.. math::
y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma
:math:`\gamma` is a learnable affine transform parameter that
matches the inner-most dimensions of the input tensor.
Parameters
----------
normalized_shape: int or iterable of int
Inner dimensions of input tensor
eps : float, default = 1e-5
A value added to the denominator for numerical stability
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
zero_centered_gamma : bool, default = 'False'
If `True`, the :math:`\gamma` parameter is initialized to zero
and the calculation changes to
.. math::
y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
sm_margin: int, default = 0
Number of SMs to exclude when launching CUDA kernels. This
helps overlap with other kernels, e.g. communication kernels.
For more fine-grained control, provide a dict with the SM
margin at each compute stage ("forward", "backward",
"inference").
"""
def __init__(
self,
normalized_shape: Iterable[int] | int,
*,
eps: float = 1e-5,
device: Optional[torch.device | str] = None,
dtype: Optional[torch.dtype] = None,
zero_centered_gamma: bool = False,
sm_margin: int = 0,
) -> None:
super().__init__()
self.eps: float = eps
self.zero_centered_gamma: bool = zero_centered_gamma
# Parameter shape
if not isinstance(normalized_shape, Iterable):
normalized_shape = (normalized_shape,)
else:
normalized_shape = tuple(normalized_shape)
self._shape: tuple[int, ...] = normalized_shape
# Parameter device
defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
self.device: torch.device = device
# Initialize parameters if needed
weight = torch.empty(
self._shape,
device="meta",
dtype=canonicalize_dtype(dtype),
)
weight = torch.nn.Parameter(weight)
self.weight: torch.nn.Parameter
self.register_parameter("weight", weight)
if not defer_param_init:
self.reset_parameters()
# Number of SMs to exclude when launching CUDA kernels
self._sm_margins: dict[str, int]
if isinstance(sm_margin, dict):
def getenv(name: str) -> int:
return int(os.getenv(name, "0"))
self._sm_margins = {
"forward": sm_margin.get("forward", getenv("NVTE_FWD_LAYERNORM_SM_MARGIN")),
"backward": sm_margin.get("backward", getenv("NVTE_BWD_LAYERNORM_SM_MARGIN")),
"inference": sm_margin.get("inference", getenv("NVTE_INF_LAYERNORM_SM_MARGIN")),
}
else:
def getenv(name: str) -> int:
return int(os.getenv(name, str(sm_margin)))
self._sm_margins = {
"forward": getenv("NVTE_FWD_LAYERNORM_SM_MARGIN"),
"backward": getenv("NVTE_BWD_LAYERNORM_SM_MARGIN"),
"inference": getenv("NVTE_INF_LAYERNORM_SM_MARGIN"),
}
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""
# Make sure parameter is initialized
weight = self.weight
if weight.device.type != "cuda":
weight = torch.empty_like(weight, device=self.device)
else:
weight = weight.to(device=self.device)
# Initialize values
if self.zero_centered_gamma:
torch.nn.init.zeros_(weight)
else:
torch.nn.init.ones_(weight)
# Save updated parameter
if not isinstance(weight, torch.nn.Parameter):
weight = torch.nn.Parameter(weight)
self.weight = weight
def pre_forward(self, *args, **kwargs) -> None:
super().pre_forward(*args, **kwargs)
if self.weight.device.type == "meta":
self.reset_parameters()
def op_forward(
self,
ctx: OperationContext,
input_: torch.Tensor,
prev_op: Optional[BasicOperation] = None,
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Check tensor dims
input_dims = tuple(input_.size())
if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={self._shape}) are not compatible"
)
# Check input tensors
inner_dim = math.prod(self._shape)
device = self.device
dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(x, QuantizedTensor):
x = x.dequantize()
if isinstance(w, QuantizedTensor):
w = w.dequantize()
# Check if backward pass is needed
requires_grad = ctx.requires_grad
# Check if FP8 is enabled
with_fp8_output = (
FP8GlobalStateManager.is_fp8_enabled()
and next_op is not None
and next_op.num_fp8_scales("input") > 0
)
output_fp8_meta = None
if with_fp8_output:
output_fp8_meta = next_op.get_fp8_meta("input")
# Compute RMSNorm
y = None
rstdevs = None
sm_margin = self._sm_margins["forward" if requires_grad else "inference"]
if with_fp8_output:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True)
args = (
x,
w,
self.eps,
output_fp8_meta[fp8_meta_key],
0, # fp8_meta_index
fp8_dtype,
sm_margin,
self.zero_centered_gamma,
)
if requires_grad:
data, rstdevs = rmsnorm_fwd_fp8(*args)
else:
data = rmsnorm_fwd_fp8_inf(*args)
y = Float8Tensor(
data=data,
fp8_meta=output_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
dtype=dtype,
)
else:
args = (
x,
w,
self.eps,
sm_margin,
self.zero_centered_gamma,
)
if requires_grad:
y, rstdevs = rmsnorm_fwd(*args)
else:
y = rmsnorm_fwd_inf(*args)
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, rstdevs)
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None
# Reshape output tensor
out = reshape(y, input_dims)
return out
def op_backward(
self,
ctx: OperationContext,
grad_output: torch.Tensor,
) -> tuple[torch.Tensor, tuple[()]]:
# Saved tensors from forward pass
x, rstdevs = ctx.saved_tensors
# Check input tensors
inner_dim = x.size(-1)
device = self.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
# Compute RMSNorm backward pass
dx, dw = rmsnorm_bwd(
dy,
x,
rstdevs,
w,
self._sm_margins["backward"],
self.zero_centered_gamma,
)
# Clear saved tensors if possible
if ctx.has_prev_op:
clear_tensor_data(x)
clear_tensor_data(rstdevs)
# Reshape results
grad_input = reshape(dx, grad_output.size())
grad_weight = reshape(dw, self._shape)
return grad_input, (grad_weight,)
...@@ -57,12 +57,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -57,12 +57,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# pylint: disable=unused-argument # pylint: disable=unused-argument
@staticmethod @staticmethod
def forward( def forward(
func_ctx: torch.autograd.function.FunctionCtx, func_ctx: Optional[torch.autograd.function.FunctionCtx],
input_: torch.Tensor, input_: torch.Tensor,
forward_ops: list[tuple[FusibleOperation, list[int]]], forward_ops: list[tuple[FusibleOperation, list[int]]],
backward_ops: list[tuple[FusibleOperation, list[int]]], backward_ops: list[tuple[FusibleOperation, list[int]]],
basic_ops: list[BasicOperation], basic_ops: list[BasicOperation],
basic_op_kwargs: list[dict[str, Any]], basic_op_kwargs: list[dict[str, Any]],
is_grad_enabled: bool,
num_params: int, num_params: int,
num_extra_inputs: int, num_extra_inputs: int,
*params_and_extra_inputs: torch.nn.Parameter, *params_and_extra_inputs: torch.nn.Parameter,
...@@ -120,10 +121,20 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -120,10 +121,20 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
# Apply forward ops # Apply forward ops
x = input_ x = input_
requires_grad = x.requires_grad requires_grad = is_grad_enabled and x.requires_grad
extra_outputs = [None for _ in range(len(basic_ops))] extra_outputs = [None for _ in range(len(basic_ops))]
for op, basic_op_idxs in forward_ops: for op, basic_op_idxs in forward_ops:
# Check if backward op is required
if is_grad_enabled:
if not requires_grad:
requires_grad = any(param.requires_grad for param in op.parameters())
if not requires_grad:
requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
for idx in basic_op_idxs:
basic_op_ctxs[idx].requires_grad = requires_grad
x.requires_grad_(requires_grad=requires_grad)
# Forward op # Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs]
...@@ -138,18 +149,12 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -138,18 +149,12 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_op_next_ops=next_ops, basic_op_next_ops=next_ops,
basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs],
) )
x.requires_grad_(requires_grad=requires_grad)
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
for y in ys:
y.requires_grad_(requires_grad=requires_grad)
extra_outputs[idx] = ys extra_outputs[idx] = ys
# Check if backward op is required
if not requires_grad:
requires_grad = any(param.requires_grad for param in op.parameters())
if not requires_grad:
requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
for idx in basic_op_idxs:
basic_op_ctxs[idx]._requires_grad = requires_grad
x.requires_grad_(requires_grad=requires_grad)
# Flatten list of extra outputs # Flatten list of extra outputs
extra_outputs_flat = [] extra_outputs_flat = []
for idx, ys in enumerate(extra_outputs): for idx, ys in enumerate(extra_outputs):
...@@ -163,6 +168,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -163,6 +168,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
) )
extra_outputs_flat.extend(ys) extra_outputs_flat.extend(ys)
# Save context for backward pass
if is_grad_enabled:
# Flatten list of saved tensors # Flatten list of saved tensors
to_save = [] to_save = []
for ctx in basic_op_ctxs: for ctx in basic_op_ctxs:
...@@ -174,7 +182,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -174,7 +182,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
ctx._saved_tensors_range = (range_start, range_end) ctx._saved_tensors_range = (range_start, range_end)
func_ctx.save_for_backward(*to_save) func_ctx.save_for_backward(*to_save)
# Other context for backward pass # Other context
func_ctx.backward_ops = backward_ops func_ctx.backward_ops = backward_ops
func_ctx.basic_ops = basic_ops func_ctx.basic_ops = basic_ops
func_ctx.basic_op_ctxs = basic_op_ctxs func_ctx.basic_op_ctxs = basic_op_ctxs
...@@ -224,7 +232,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -224,7 +232,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
for op, basic_op_idxs in backward_ops: for op, basic_op_idxs in backward_ops:
# Stop if no more gradients are required # Stop if no more gradients are required
if all(not basic_op_ctxs[idx]._requires_grad for idx in basic_op_idxs): if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs):
dx = None dx = None
break break
...@@ -282,6 +290,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -282,6 +290,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
None, # backward_ops None, # backward_ops
None, # basic_ops None, # basic_ops
None, # basic_op_kwargs None, # basic_op_kwargs
None, # is_grad_enabled
None, # num_params None, # num_params
None, # num_extra_inputs None, # num_extra_inputs
*grad_params_flat, *grad_params_flat,
...@@ -373,14 +382,23 @@ class OperationFuser: ...@@ -373,14 +382,23 @@ class OperationFuser:
params = [param for op in self._basic_ops for param in op.parameters()] params = [param for op in self._basic_ops for param in op.parameters()]
# Fuser forward pass # Fuser forward pass
return _OperationFuserAutogradFunction.apply( is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled:
forward_func = _OperationFuserAutogradFunction.apply
args = []
else:
forward_func = _OperationFuserAutogradFunction.forward
args = [None]
args += (
input, input,
self._forward_ops, self._forward_ops,
self._backward_ops, self._backward_ops,
self._basic_ops, self._basic_ops,
basic_op_kwargs, basic_op_kwargs,
is_grad_enabled,
len(params), len(params),
self._num_extra_inputs, self._num_extra_inputs,
*params, *params,
*extra_inputs, *extra_inputs,
) )
return forward_func(*args)
...@@ -43,7 +43,7 @@ class OperationContext: ...@@ -43,7 +43,7 @@ class OperationContext:
_saved_tensors_range: Optional[tuple[int, int]] = None _saved_tensors_range: Optional[tuple[int, int]] = None
# Whether backward pass is required # Whether backward pass is required
_requires_grad: bool = False requires_grad: bool = True
def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None: def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None:
"""Register tensors to be saved for the backward function """Register tensors to be saved for the backward function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment