Unverified Commit 9d976bcd authored by Tim Moon, committed by GitHub

[PyTorch] Minor optimizations to reduce CPU overheads in modules (#1191)



* CPU perf optimization in linear autograd function

Avoid enable_grad context when possible in cast function. Cache distributed group properties.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
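
A minimal sketch of both ideas, assuming simplified signatures and no Transformer Engine internals: the cast helper returns early when no conversion is needed, so `torch.enable_grad()` is only entered when a cast actually happens, and per-group queries are memoized with `functools.lru_cache`, which is safe because group membership is fixed after initialization.

```python
from functools import lru_cache
from typing import Optional

import torch
import torch.distributed as dist


def cast_if_needed(tensor: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
    """Cast tensor to dtype, doing no work when no cast is required."""
    if tensor is None or tensor.dtype == dtype:
        # Hot path: skip the enable_grad context entirely.
        return tensor
    with torch.enable_grad():
        # Only an actual cast pays for the context manager.
        return tensor.to(dtype=dtype)


@lru_cache
def get_distributed_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
    """Return (and cache) the world size for the distributed group."""
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=group)
```

One caveat of the caching approach: `lru_cache` keeps a reference to every group it has seen, which is acceptable here because process groups normally live for the whole job.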

* CPU perf optimization in prepare_forward function

Avoid torch.nn.Module impl of __setattr__.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
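
A self-contained sketch of the shortcut (toy module, illustrative attribute names): `torch.nn.Module.__setattr__` checks every assignment against parameters, buffers, and submodules, so attributes known to be plain Python values can be written straight into `__dict__`. The final diff routes this through an overridden `__setattr__`, per the later "Move _fast_setattr logic" commit.

```python
from typing import Any, Set

import torch


class FastAttrModule(torch.nn.Module):
    """Toy module demonstrating the __setattr__ fast path."""

    # Attributes that are never parameters, buffers, or submodules.
    _fast_setattr_names: Set[str] = {"activation_dtype", "fp8"}

    def __setattr__(self, name: str, value: Any) -> None:
        if name in FastAttrModule._fast_setattr_names:
            # Bypass torch.nn.Module.__setattr__ and its per-assignment
            # parameter/buffer/module bookkeeping.
            self.__dict__[name] = value
        else:
            super().__setattr__(name, value)


m = FastAttrModule()
m.activation_dtype = torch.float32             # plain dict write
m.weight = torch.nn.Parameter(torch.zeros(2))  # still registered normally
assert "weight" in dict(m.named_parameters())
```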

* Avoid module import in TE module forwards
Signed-off-by: Tim Moon <tmoon@nvidia.com>
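
A stand-alone sketch of the pattern (everything in one file here; in the diff the flag and accessor live in `cpu_offload.py`): the forward path calls a small accessor that is bound once at module load, instead of re-executing a `from ..cpu_offload import CPUOffloadEnabled` statement on every call.

```python
# Stand-in for the cpu_offload module's global flag and accessor.
CPUOffloadEnabled = False


def is_cpu_offload_enabled() -> bool:
    """Check if CPU offloading is currently enabled."""
    return CPUOffloadEnabled


def forward_step(x: float) -> float:
    # No per-call import machinery; the accessor still reads the *current*
    # value of the flag, which is why the flag itself cannot simply be
    # imported once at module scope.
    if is_cpu_offload_enabled():
        pass  # offload saved activations here
    return 2.0 * x
```

Reading the flag through a function rather than importing the boolean avoids a stale snapshot, which is presumably why the import previously sat inside `forward`.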

* Use fast getter for params
Signed-off-by: Tim Moon <tmoon@nvidia.com>
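
A runnable sketch of the fast getter on a toy module: `nn.Module` stores parameters in the `_parameters` dict rather than as instance attributes, so `self.weight0` goes through `__getattribute__`, raises `AttributeError`, and falls back to the custom `__getattr__`; binding the dict's `.get` method once skips that detour.

```python
import torch


class ParamModule(torch.nn.Module):
    """Toy module with a cached fast parameter getter."""

    def __init__(self) -> None:
        super().__init__()
        self.weight0 = torch.nn.Parameter(torch.zeros(4, 4))
        # Bind the _parameters dict's .get method once; later lookups are a
        # plain dict access instead of the __getattr__ fallback path.
        self._fast_get_param = self.__dict__["_parameters"].get


m = ParamModule()
assert m._fast_get_param("weight0") is m.weight0
assert m._fast_get_param("no_such_param") is None  # .get semantics
```

The cached getter stays valid as long as the `_parameters` dict object itself is never replaced.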

* Reuse tensor dims in linear autograd func
Signed-off-by: Tim Moon <tmoon@nvidia.com>
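
A simplified sketch (plain GEMM, no FP8 or tensor parallelism): weight and input shapes are read once and reused, rather than calling `.shape`/`.size()` repeatedly, since each such call builds a fresh `torch.Size` on the Python side.

```python
import torch


def linear_forward(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Minimal linear forward that caches tensor dims up front."""
    out_features, in_features = weight.shape  # read once
    inp_shape = inp.shape                     # read once
    assert inp_shape[-1] == in_features, "GEMM not possible"

    inputmat = inp.view(-1, in_features)
    out = inputmat @ weight.t()

    # Reuse the cached dims; this mirrors the inp_shape / out_features reuse
    # in the autograd functions below.
    return out.view(-1, *inp_shape[1:-1], out_features)


y = linear_forward(torch.randn(8, 16, 32), torch.randn(64, 32))
assert y.shape == (8, 16, 64)
```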

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Apply optimizations to grouped linear
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid deepcopy in tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>
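
A sketch of the test-side change with a stand-in module (the real test builds two `TransformerLayer(*block_args, **block_kwargs)` instances): construct the second module from the same arguments and copy parameters over under `no_grad`, instead of `copy.deepcopy`, which is slow for large modules.

```python
import torch

torch.manual_seed(0)
block = torch.nn.Linear(16, 16)          # stand-in for TransformerLayer(*args, **kwargs)
graphed_block = torch.nn.Linear(16, 16)  # second instance from the same args

# Copy parameters instead of deep-copying the whole module.
with torch.no_grad():
    for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
        param2.copy_(param1)

for p1, p2 in zip(block.parameters(), graphed_block.parameters()):
    assert torch.equal(p1, p2)
```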

* Move _fast_setattr logic to __setattr__ method
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 10cceae9
......@@ -1602,10 +1602,12 @@ def test_gpt_cuda_graph(dtype, bs, model):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer(
block_args = (
config.hidden_size,
4 * config.hidden_size,
config.num_attention_heads,
)
block_kwargs = dict(
layernorm_epsilon=config.eps,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
......@@ -1617,7 +1619,11 @@ def test_gpt_cuda_graph(dtype, bs, model):
output_layernorm=False,
device="cuda",
)
graphed_block = copy.deepcopy(block)
block = TransformerLayer(*block_args, **block_kwargs)
graphed_block = TransformerLayer(*block_args, **block_kwargs)
with torch.no_grad():
for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
param2.copy_(param1)
out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
......
......@@ -16,6 +16,11 @@ __all__ = ["get_cpu_offload_context"]
CPUOffloadEnabled = False
def is_cpu_offload_enabled() -> bool:
"""Check if CPU offloading is currently enabled."""
return CPUOffloadEnabled
class CpuOffloadSavedTensorHook:
"""Contex-manager that executes a pair of pack/unpack hooks for saved tensors.
......
......@@ -6,6 +6,7 @@
from __future__ import annotations
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
......@@ -125,6 +126,7 @@ def set_tensor_model_parallel_attributes(
setattr(tensor, "partition_stride", stride)
@lru_cache
def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
"""Return world size for the distributed group."""
if not torch.distributed.is_initialized():
......@@ -132,6 +134,7 @@ def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
return torch.distributed.get_world_size(group=group)
@lru_cache
def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
"""Return my rank for the distributed group."""
assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
......
......@@ -109,7 +109,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_"""
with torch.cuda.amp.autocast(enabled=False):
if bias.numel() != 0:
if bias is not None and bias.numel() != 0:
return bias_gelu_fused_(inp, bias)
return gelu_fused_(inp)
......@@ -119,7 +119,7 @@ def bgrad_dgelu_fused(
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""Disable native AMP for `bgrad_dgelu_fused_`"""
with torch.cuda.amp.autocast(enabled=False):
if bias.numel() != 0:
if bias is not None and bias.numel() != 0:
return bgrad_dgelu_fused_(grad_output, inp, bias)
return None, dgelu_fused_(grad_output, inp)
......
......@@ -11,7 +11,7 @@ import socket
import fcntl
import struct
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager
import torch
......@@ -406,6 +406,36 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fsdp_wrapped = False
self.fsdp_group = None
self._fp8_workspaces: Dict[str, Float8Tensor] = {}
self.activation_dtype: Optional[torch.dtype] = None
# Fast getter for parameters
# Note: torch.nn.Module does not store parameters like normal
# attrs, but rather in a dict. When attempting to access, the
# module will raise an AttributeError in __getattribute__ and
# call a custom __getattr__. This is unnecessary overhead if
# we know we are accessing a parameter.
self._fast_get_param: Callable[[str], torch.nn.Parameter]
self._fast_get_param = self.__dict__["_parameters"].get
# Names of attributes that can be set quickly (see __setattr__
# method)
_fast_setattr_names: Set[str] = {
"activation_dtype",
"fp8",
"fp8_initialized",
"fp8_calibration",
"fp8_parameters",
}
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
"""Increase or decrease size of amax history based on given `length`.
......@@ -593,7 +623,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return
# All checks after this have already been performed once, thus skip
if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype:
if self.activation_dtype == inp.dtype:
return
dtype = inp.dtype
......@@ -708,10 +738,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
if not allow_non_contiguous:
yield inp.contiguous()
else:
yield inp
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
yield inp
if self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
......
......@@ -39,8 +39,9 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ..tensor import Float8Tensor, QuantizedTensor
from ..export import is_in_onnx_export_mode
from ..cpu_offload import is_cpu_offload_enabled
__all__ = ["GroupedLinear"]
......@@ -715,11 +716,11 @@ class GroupedLinear(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp:
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
weight_tensors = [self._fast_get_param(f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8:
weight_tensors = [
w.from_float8() if isinstance(w, Float8Tensor) else w for w in weight_tensors
w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
]
weight_tensors_fp8 = [None] * self.num_gemms
......@@ -746,8 +747,6 @@ class GroupedLinear(TransformerEngineBaseModule):
skip_update_flag=skip_fp8_weight_update,
)
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled():
linear_fn = _GroupedLinear.apply
args = []
......@@ -763,7 +762,7 @@ class GroupedLinear(TransformerEngineBaseModule):
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
self._offsets,
......
......@@ -12,7 +12,6 @@ from torch.nn.parameter import Parameter
from torch.nn import init
import transformer_engine_torch as tex
from .base import TransformerEngineBaseModule
from ..cpp_extensions import (
layernorm_fwd_inf,
)
......@@ -143,6 +142,7 @@ class LayerNorm(torch.nn.Module):
)
)
self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None
self.reset_parameters(defer_init=(device == "meta"))
......@@ -186,8 +186,21 @@ class LayerNorm(torch.nn.Module):
@no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""
# Set the activation type for AMP.
TransformerEngineBaseModule.set_activation_dtype(self, inp)
# Note: This will soon be deprecated with
# https://github.com/NVIDIA/TransformerEngine/pull/1033
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
elif self.activation_dtype != inp.dtype:
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype
if torch.is_grad_enabled():
fwd_fn = _LayerNorm.apply
......
......@@ -46,6 +46,7 @@ from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor
from ..cpu_offload import is_cpu_offload_enabled
__all__ = ["LayerNormLinear"]
......@@ -94,8 +95,9 @@ class _LayerNormLinear(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
out_features, in_features = weight.shape
inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
if fp8:
assert_dim_for_fp8_exec(inputmat)
......@@ -339,7 +341,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.inp_shape = inp_shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.tp_size = tp_size
......@@ -369,7 +371,7 @@ class _LayerNormLinear(torch.autograd.Function):
out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
out = out.view(-1, *inp_shape[1:-1], out_features)
if return_layernorm_output:
if return_layernorm_output_gathered:
......@@ -1149,7 +1151,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch) as inp:
# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
......@@ -1160,11 +1162,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
)
bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
bias_tensor = self._fast_get_param(self.bias_names[0]) # Unused
# Initialize FP8 weights if needed
weight_fp8 = None
......@@ -1190,8 +1190,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
skip_update_flag=skip_fp8_weight_update,
)
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply
args = []
......@@ -1200,8 +1198,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
args = [None]
args += (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self._fast_get_param("layer_norm_weight"),
self._fast_get_param("layer_norm_bias"),
weight_tensor,
weight_fp8,
bias_tensor,
......@@ -1212,7 +1210,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
......
......@@ -54,6 +54,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ._common import _apply_normalization
from ..cpu_offload import is_cpu_offload_enabled
__all__ = ["LayerNormMLP"]
......@@ -124,7 +125,8 @@ class _LayerNormMLP(torch.autograd.Function):
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
if fp8:
assert_dim_for_fp8_exec(inputmat)
......@@ -433,7 +435,8 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight.weight_offloading = True
fc1_weight.weight_offloading = True
fc2_weight.weight_offloading = True
fc1_bias.weight_offloading = True
if fc1_bias is not None:
fc1_bias.weight_offloading = True
inputmat.activation_offloading = True
if normalization == "LayerNorm":
......@@ -487,7 +490,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.use_fc2_bias = use_fc2_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.inp_shape = inp_shape
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
......@@ -519,11 +522,11 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_out, _ = allreduce(fc2_out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1])
fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp.shape)
shape = list(inp_shape)
shape[0] *= tp_size
return fc2_out, ln_out_return.view(shape)
return fc2_out, ln_out_return.view_as(inp)
......@@ -1470,8 +1473,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
# Get weight tensors
fc1_weight = self.fc1_weight
fc2_weight = self.fc2_weight
fc1_weight = self._fast_get_param("fc1_weight")
fc1_bias = self._fast_get_param("fc1_bias")
fc2_weight = self._fast_get_param("fc2_weight")
fc2_bias = self._fast_get_param("fc2_bias")
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.from_float8()
......@@ -1524,8 +1529,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled():
fwd_fn = _LayerNormMLP.apply
args = []
......@@ -1534,15 +1537,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
args = [None]
args += (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self._fast_get_param("layer_norm_weight"),
self._fast_get_param("layer_norm_bias"),
fc1_weight,
fc1_weight_fp8,
self.fc1_bias,
fc1_bias,
self.use_bias,
fc2_weight,
fc2_weight_fp8,
self.fc2_bias,
fc2_bias,
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
is_first_microbatch,
......@@ -1550,7 +1553,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
......@@ -1580,12 +1583,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
out, ln_out = out
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(self.fc2_bias, self.activation_dtype)
out = out + cast_if_needed(fc2_bias, self.activation_dtype)
if self.return_bias:
if self.return_layernorm_output:
return out, cast_if_needed(self.fc2_bias, self.activation_dtype), ln_out
return out, cast_if_needed(self.fc2_bias, self.activation_dtype)
return out, cast_if_needed(fc2_bias, self.activation_dtype), ln_out
return out, cast_if_needed(fc2_bias, self.activation_dtype)
if self.return_layernorm_output:
return out, ln_out
return out
......@@ -48,6 +48,7 @@ from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor
from ..cpu_offload import is_cpu_offload_enabled
__all__ = ["Linear"]
......@@ -87,8 +88,9 @@ class _Linear(torch.autograd.Function):
is_input_fp8 = isinstance(inp, Float8Tensor)
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
out_features, in_features = weight.shape
inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view(-1, in_features)
if fp8:
assert_dim_for_fp8_exec(inputmat)
......@@ -180,7 +182,7 @@ class _Linear(torch.autograd.Function):
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight_fp8.size(0)
dim_size[1] = out_features
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj_projout.is_p2p_overlap():
if ub_obj_projout.is_atomic_gemm():
......@@ -200,7 +202,7 @@ class _Linear(torch.autograd.Function):
ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight_fp8.size(0)
dim_size[1] = out_features
out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device)
_ = fp8_gemm(
......@@ -260,7 +262,7 @@ class _Linear(torch.autograd.Function):
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group)
dim_size[1] = weight.size(0)
dim_size[1] = out_features
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj_projout.is_p2p_overlap():
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
......@@ -268,7 +270,7 @@ class _Linear(torch.autograd.Function):
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0)
dim_size[1] = out_features
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
_ = gemm(
......@@ -334,7 +336,7 @@ class _Linear(torch.autograd.Function):
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.inp_shape = inp_shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.ub_overlap_ag = ub_overlap_ag
......@@ -358,7 +360,7 @@ class _Linear(torch.autograd.Function):
out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
return out.view(-1, *inp_shape[1:-1], out_features)
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
......@@ -941,7 +943,7 @@ class Linear(TransformerEngineBaseModule):
) as inp:
# Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names]
unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8:
if len(unfused_weights) != 1:
......@@ -952,11 +954,9 @@ class Linear(TransformerEngineBaseModule):
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
)
bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
bias_tensor = self._fast_get_param(self.bias_names[0]) # Unused
# Initialize FP8 weights if needed
weight_fp8 = None
......@@ -983,8 +983,6 @@ class Linear(TransformerEngineBaseModule):
fsdp_group=self.fsdp_group,
)
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled():
linear_fn = _Linear.apply
args = []
......@@ -1002,7 +1000,7 @@ class Linear(TransformerEngineBaseModule):
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
......
......@@ -11,7 +11,6 @@ import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from .base import TransformerEngineBaseModule
from .. import cpp_extensions as tex
from ..jit import no_torch_dynamo
from ..utils import cast_if_needed
......@@ -146,6 +145,7 @@ class RMSNorm(torch.nn.Module):
)
)
self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None
self.reset_parameters(defer_init=(device == "meta"))
......@@ -185,7 +185,19 @@ class RMSNorm(torch.nn.Module):
"""RMSNorm FWD"""
# Set the activation type for AMP.
TransformerEngineBaseModule.set_activation_dtype(self, inp)
# Note: This will soon be deprecated with
# https://github.com/NVIDIA/TransformerEngine/pull/1033
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
elif self.activation_dtype != inp.dtype:
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype
if torch.is_grad_enabled():
fwd_fn = _RMSNorm.apply
......
......@@ -751,7 +751,7 @@ class TransformerLayer(torch.nn.Module):
return output
def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
if drop_path is None and bias.numel() != 0:
if drop_path is None and bias is not None and bias.numel() != 0:
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
......@@ -763,7 +763,7 @@ class TransformerLayer(torch.nn.Module):
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(hidden_state, bias, residual, self.hidden_dropout)
else:
if bias.numel() != 0:
if bias is not None and bias.numel() != 0:
hidden_state = hidden_state + bias
out = torch.nn.functional.dropout(
hidden_state, p=self.hidden_dropout, training=self.training
......
......@@ -218,8 +218,12 @@ def safely_set_viewless_tensor_data(tensor: torch.Tensor, new_data_tensor: torch
def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""Cast tensor to dtype"""
if tensor is None:
return None
if tensor.dtype == dtype:
return tensor
with torch.enable_grad():
return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)
return tensor.to(dtype=dtype)
def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool:
......