"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "3e6859e22f5a3b7969f6068f00e148bd825775ad"
Unverified Commit 9d976bcd authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Minor optimizations to reduce CPU overheads in modules (#1191)



* CPU perf optimization in linear autograd function

Avoid enable_grad context when possible in cast function. Cache distributed group properties.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* CPU perf optimization in prepare_forward function

Avoid torch.nn.Module impl of __setattr__.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Avoid module import in TE module forwards
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Use fast getter for params
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Reuse tensor dims in linear autograd func
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Apply optimizations to grouped linear
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Debug test failures
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug test failures
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Avoid deepcopy in tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Move _fast_setattr logic to __setattr__ method
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 10cceae9
...@@ -1602,10 +1602,12 @@ def test_gpt_cuda_graph(dtype, bs, model): ...@@ -1602,10 +1602,12 @@ def test_gpt_cuda_graph(dtype, bs, model):
init_method = init_method_normal(sigma) init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
block = TransformerLayer( block_args = (
config.hidden_size, config.hidden_size,
4 * config.hidden_size, 4 * config.hidden_size,
config.num_attention_heads, config.num_attention_heads,
)
block_kwargs = dict(
layernorm_epsilon=config.eps, layernorm_epsilon=config.eps,
init_method=init_method, init_method=init_method,
output_layer_init_method=output_layer_init_method, output_layer_init_method=output_layer_init_method,
...@@ -1617,7 +1619,11 @@ def test_gpt_cuda_graph(dtype, bs, model): ...@@ -1617,7 +1619,11 @@ def test_gpt_cuda_graph(dtype, bs, model):
output_layernorm=False, output_layernorm=False,
device="cuda", device="cuda",
) )
graphed_block = copy.deepcopy(block) block = TransformerLayer(*block_args, **block_kwargs)
graphed_block = TransformerLayer(*block_args, **block_kwargs)
with torch.no_grad():
for param1, param2 in zip(block.parameters(), graphed_block.parameters()):
param2.copy_(param1)
out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False) out, grads = _test_gpt_e2e_cuda_graph(block, bs, dtype, config, False)
graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True) graphed_out, graphed_grads = _test_gpt_e2e_cuda_graph(graphed_block, bs, dtype, config, True)
......
...@@ -16,6 +16,11 @@ __all__ = ["get_cpu_offload_context"] ...@@ -16,6 +16,11 @@ __all__ = ["get_cpu_offload_context"]
CPUOffloadEnabled = False CPUOffloadEnabled = False
def is_cpu_offload_enabled() -> bool:
"""Check if CPU offloading is currently enabled."""
return CPUOffloadEnabled
class CpuOffloadSavedTensorHook: class CpuOffloadSavedTensorHook:
"""Contex-manager that executes a pair of pack/unpack hooks for saved tensors. """Contex-manager that executes a pair of pack/unpack hooks for saved tensors.
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager, AbstractContextManager, ContextDecorator from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings import warnings
...@@ -125,6 +126,7 @@ def set_tensor_model_parallel_attributes( ...@@ -125,6 +126,7 @@ def set_tensor_model_parallel_attributes(
setattr(tensor, "partition_stride", stride) setattr(tensor, "partition_stride", stride)
@lru_cache
def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int: def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
"""Return world size for the distributed group.""" """Return world size for the distributed group."""
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
...@@ -132,6 +134,7 @@ def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int: ...@@ -132,6 +134,7 @@ def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
return torch.distributed.get_world_size(group=group) return torch.distributed.get_world_size(group=group)
@lru_cache
def get_distributed_rank(group: Optional[dist_group_type] = None) -> int: def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
"""Return my rank for the distributed group.""" """Return my rank for the distributed group."""
assert torch.distributed.is_initialized(), "torch.distributed is not initialized." assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
......
...@@ -109,7 +109,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: ...@@ -109,7 +109,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_""" """Disable native AMP for bias_gelu_fused_"""
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
if bias.numel() != 0: if bias is not None and bias.numel() != 0:
return bias_gelu_fused_(inp, bias) return bias_gelu_fused_(inp, bias)
return gelu_fused_(inp) return gelu_fused_(inp)
...@@ -119,7 +119,7 @@ def bgrad_dgelu_fused( ...@@ -119,7 +119,7 @@ def bgrad_dgelu_fused(
) -> Tuple[Optional[torch.Tensor], torch.Tensor]: ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""Disable native AMP for `bgrad_dgelu_fused_`""" """Disable native AMP for `bgrad_dgelu_fused_`"""
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
if bias.numel() != 0: if bias is not None and bias.numel() != 0:
return bgrad_dgelu_fused_(grad_output, inp, bias) return bgrad_dgelu_fused_(grad_output, inp, bias)
return None, dgelu_fused_(grad_output, inp) return None, dgelu_fused_(grad_output, inp)
......
...@@ -11,7 +11,7 @@ import socket ...@@ -11,7 +11,7 @@ import socket
import fcntl import fcntl
import struct import struct
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager from contextlib import contextmanager
import torch import torch
...@@ -406,6 +406,36 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -406,6 +406,36 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.fsdp_wrapped = False self.fsdp_wrapped = False
self.fsdp_group = None self.fsdp_group = None
self._fp8_workspaces: Dict[str, Float8Tensor] = {} self._fp8_workspaces: Dict[str, Float8Tensor] = {}
self.activation_dtype: Optional[torch.dtype] = None
# Fast getter for parameters
# Note: torch.nn.Module does not store parameters like normal
# attrs, but rather in a dict. When attempting to access, the
# module will raise an AttributeError in __getattribute__ and
# call a custom __getattr__. This is unnecessary overhead if
# we know we are accessing a parameter.
self._fast_get_param: Callable[str, torch.nn.Parameter]
self._fast_get_param = self.__dict__["_parameters"].get
# Names of attributes that can be set quickly (see __setattr__
# method)
_fast_setattr_names: Set[str] = {
"activation_dtype",
"fp8",
"fp8_initialized",
"fp8_calibration",
"fp8_parameters",
}
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
"""Increase or decrease size of amax history based on given `length`. """Increase or decrease size of amax history based on given `length`.
...@@ -593,7 +623,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -593,7 +623,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return return
# All checks after this have already been performed once, thus skip # All checks after this have already been performed once, thus skip
if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype: if self.activation_dtype == inp.dtype:
return return
dtype = inp.dtype dtype = inp.dtype
...@@ -708,9 +738,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -708,9 +738,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta) FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"): with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
if not allow_non_contiguous: if not allow_non_contiguous and not inp.is_contiguous():
yield inp.contiguous() inp = inp.contiguous()
else:
yield inp yield inp
if self.fp8 and in_fp8_activation_recompute_phase(): if self.fp8 and in_fp8_activation_recompute_phase():
......
...@@ -39,8 +39,9 @@ from ..cpp_extensions import ( ...@@ -39,8 +39,9 @@ from ..cpp_extensions import (
from ..constants import GemmParallelModes, dist_group_type from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor from ..tensor import Float8Tensor, QuantizedTensor
from ..export import is_in_onnx_export_mode from ..export import is_in_onnx_export_mode
from ..cpu_offload import is_cpu_offload_enabled
__all__ = ["GroupedLinear"] __all__ = ["GroupedLinear"]
...@@ -715,11 +716,11 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -715,11 +716,11 @@ class GroupedLinear(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp: with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp:
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] weight_tensors = [self._fast_get_param(f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8: if not self.fp8:
weight_tensors = [ weight_tensors = [
w.from_float8() if isinstance(w, Float8Tensor) else w for w in weight_tensors w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
] ]
weight_tensors_fp8 = [None] * self.num_gemms weight_tensors_fp8 = [None] * self.num_gemms
...@@ -746,8 +747,6 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -746,8 +747,6 @@ class GroupedLinear(TransformerEngineBaseModule):
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
) )
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled(): if torch.is_grad_enabled():
linear_fn = _GroupedLinear.apply linear_fn = _GroupedLinear.apply
args = [] args = []
...@@ -763,7 +762,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -763,7 +762,7 @@ class GroupedLinear(TransformerEngineBaseModule):
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
CPUOffloadEnabled, is_cpu_offload_enabled(),
self.sequence_parallel, self.sequence_parallel,
self.activation_dtype, self.activation_dtype,
self._offsets, self._offsets,
......
...@@ -12,7 +12,6 @@ from torch.nn.parameter import Parameter ...@@ -12,7 +12,6 @@ from torch.nn.parameter import Parameter
from torch.nn import init from torch.nn import init
import transformer_engine_torch as tex import transformer_engine_torch as tex
from .base import TransformerEngineBaseModule
from ..cpp_extensions import ( from ..cpp_extensions import (
layernorm_fwd_inf, layernorm_fwd_inf,
) )
...@@ -143,6 +142,7 @@ class LayerNorm(torch.nn.Module): ...@@ -143,6 +142,7 @@ class LayerNorm(torch.nn.Module):
) )
) )
self.sequence_parallel = sequence_parallel self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None
self.reset_parameters(defer_init=(device == "meta")) self.reset_parameters(defer_init=(device == "meta"))
...@@ -186,8 +186,21 @@ class LayerNorm(torch.nn.Module): ...@@ -186,8 +186,21 @@ class LayerNorm(torch.nn.Module):
@no_torch_dynamo() @no_torch_dynamo()
def forward(self, inp: torch.Tensor) -> torch.Tensor: def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD""" """LayerNorm FWD"""
# Set the activation type for AMP. # Set the activation type for AMP.
TransformerEngineBaseModule.set_activation_dtype(self, inp) # Note: This will soon be deprecated with
# https://github.com/NVIDIA/TransformerEngine/pull/1033
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
elif self.activation_dtype != inp.dtype:
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype
if torch.is_grad_enabled(): if torch.is_grad_enabled():
fwd_fn = _LayerNorm.apply fwd_fn = _LayerNorm.apply
......
...@@ -46,6 +46,7 @@ from ._common import _apply_normalization, _noop_cat ...@@ -46,6 +46,7 @@ from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor from ..tensor import QuantizedTensor
from ..cpu_offload import is_cpu_offload_enabled
__all__ = ["LayerNormLinear"] __all__ = ["LayerNormLinear"]
...@@ -94,8 +95,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -94,8 +95,9 @@ class _LayerNormLinear(torch.autograd.Function):
fsdp_group: Union[dist_group_type, None], fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() out_features, in_features = weight.shape
assert inp.shape[-1] == in_features, "GEMM not possible" inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features)) inputmat = inp.view((-1, in_features))
if fp8: if fp8:
assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
...@@ -339,7 +341,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -339,7 +341,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.use_bias = use_bias ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape ctx.inp_shape = inp_shape
ctx.parallel_mode = parallel_mode ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.tp_size = tp_size ctx.tp_size = tp_size
...@@ -369,7 +371,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -369,7 +371,7 @@ class _LayerNormLinear(torch.autograd.Function):
out, _ = allreduce(out, tp_group) out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp.shape[1:-1], out.shape[-1]) out = out.view(-1, *inp_shape[1:-1], out_features)
if return_layernorm_output: if return_layernorm_output:
if return_layernorm_output_gathered: if return_layernorm_output_gathered:
...@@ -1149,7 +1151,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1149,7 +1151,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch) as inp: with self.prepare_forward(inp, is_first_microbatch) as inp:
# Get concatenated weight and bias tensors # Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names] unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8: if self.fp8:
if len(unfused_weights) != 1: if len(unfused_weights) != 1:
...@@ -1160,11 +1162,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1160,11 +1162,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
unfused_weights = [w.dequantize() for w in unfused_weights] unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights) weight_tensor = _noop_cat(unfused_weights)
if self.use_bias: if self.use_bias:
bias_tensor = _noop_cat( bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
[getattr(self, name) for name in self.bias_names],
)
else: else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused bias_tensor = self._fast_get_param(self.bias_names[0]) # Unused
# Initialize FP8 weights if needed # Initialize FP8 weights if needed
weight_fp8 = None weight_fp8 = None
...@@ -1190,8 +1190,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1190,8 +1190,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
skip_update_flag=skip_fp8_weight_update, skip_update_flag=skip_fp8_weight_update,
) )
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled(): if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply fwd_fn = _LayerNormLinear.apply
args = [] args = []
...@@ -1200,8 +1198,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1200,8 +1198,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
args = [None] args = [None]
args += ( args += (
inp, inp,
self.layer_norm_weight, self._fast_get_param("layer_norm_weight"),
self.layer_norm_bias, self._fast_get_param("layer_norm_bias"),
weight_tensor, weight_tensor,
weight_fp8, weight_fp8,
bias_tensor, bias_tensor,
...@@ -1212,7 +1210,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1212,7 +1210,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
CPUOffloadEnabled, is_cpu_offload_enabled(),
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
self.sequence_parallel, self.sequence_parallel,
......
...@@ -54,6 +54,7 @@ from ..jit import no_torch_dynamo ...@@ -54,6 +54,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
from ._common import _apply_normalization from ._common import _apply_normalization
from ..cpu_offload import is_cpu_offload_enabled
__all__ = ["LayerNormMLP"] __all__ = ["LayerNormMLP"]
...@@ -124,7 +125,8 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -124,7 +125,8 @@ class _LayerNormMLP(torch.autograd.Function):
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible" inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features)) inputmat = inp.view((-1, in_features))
if fp8: if fp8:
assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
...@@ -433,6 +435,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -433,6 +435,7 @@ class _LayerNormMLP(torch.autograd.Function):
ln_weight.weight_offloading = True ln_weight.weight_offloading = True
fc1_weight.weight_offloading = True fc1_weight.weight_offloading = True
fc2_weight.weight_offloading = True fc2_weight.weight_offloading = True
if fc1_bias is not None:
fc1_bias.weight_offloading = True fc1_bias.weight_offloading = True
inputmat.activation_offloading = True inputmat.activation_offloading = True
...@@ -487,7 +490,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -487,7 +490,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.use_fc2_bias = use_fc2_bias ctx.use_fc2_bias = use_fc2_bias
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape ctx.inp_shape = inp_shape
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.tp_size = tp_size ctx.tp_size = tp_size
ctx.bias_gelu_nvfusion = bias_gelu_nvfusion ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
...@@ -519,11 +522,11 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -519,11 +522,11 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_out, _ = allreduce(fc2_out, tp_group) fc2_out, _ = allreduce(fc2_out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1]) fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
if return_layernorm_output: if return_layernorm_output:
if return_layernorm_output_gathered: if return_layernorm_output_gathered:
shape = list(inp.shape) shape = list(inp_shape)
shape[0] *= tp_size shape[0] *= tp_size
return fc2_out, ln_out_return.view(shape) return fc2_out, ln_out_return.view(shape)
return fc2_out, ln_out_return.view_as(inp) return fc2_out, ln_out_return.view_as(inp)
...@@ -1470,8 +1473,10 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1470,8 +1473,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
# Get weight tensors # Get weight tensors
fc1_weight = self.fc1_weight fc1_weight = self._fast_get_param("fc1_weight")
fc2_weight = self.fc2_weight fc1_bias = self._fast_get_param("fc1_bias")
fc2_weight = self._fast_get_param("fc2_weight")
fc2_bias = self._fast_get_param("fc2_bias")
if not self.fp8: if not self.fp8:
if isinstance(fc1_weight, Float8Tensor): if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.from_float8() fc1_weight = fc1_weight.from_float8()
...@@ -1524,8 +1529,6 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1524,8 +1529,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False self.bias_gelu_nvfusion = False
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled(): if torch.is_grad_enabled():
fwd_fn = _LayerNormMLP.apply fwd_fn = _LayerNormMLP.apply
args = [] args = []
...@@ -1534,15 +1537,15 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1534,15 +1537,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
args = [None] args = [None]
args += ( args += (
inp, inp,
self.layer_norm_weight, self._fast_get_param("layer_norm_weight"),
self.layer_norm_bias, self._fast_get_param("layer_norm_bias"),
fc1_weight, fc1_weight,
fc1_weight_fp8, fc1_weight_fp8,
self.fc1_bias, fc1_bias,
self.use_bias, self.use_bias,
fc2_weight, fc2_weight,
fc2_weight_fp8, fc2_weight_fp8,
self.fc2_bias, fc2_bias,
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
self.eps, self.eps,
is_first_microbatch, is_first_microbatch,
...@@ -1550,7 +1553,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1550,7 +1553,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
CPUOffloadEnabled, is_cpu_offload_enabled(),
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
self.sequence_parallel, self.sequence_parallel,
...@@ -1580,12 +1583,12 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1580,12 +1583,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
out, ln_out = out out, ln_out = out
if self.gemm_bias_unfused_add: if self.gemm_bias_unfused_add:
out = out + cast_if_needed(self.fc2_bias, self.activation_dtype) out = out + cast_if_needed(fc2_bias, self.activation_dtype)
if self.return_bias: if self.return_bias:
if self.return_layernorm_output: if self.return_layernorm_output:
return out, cast_if_needed(self.fc2_bias, self.activation_dtype), ln_out return out, cast_if_needed(fc2_bias, self.activation_dtype), ln_out
return out, cast_if_needed(self.fc2_bias, self.activation_dtype) return out, cast_if_needed(fc2_bias, self.activation_dtype)
if self.return_layernorm_output: if self.return_layernorm_output:
return out, ln_out return out, ln_out
return out return out
...@@ -48,6 +48,7 @@ from ..graph import is_graph_capturing ...@@ -48,6 +48,7 @@ from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor from ..tensor import QuantizedTensor
from ..cpu_offload import is_cpu_offload_enabled
__all__ = ["Linear"] __all__ = ["Linear"]
...@@ -87,8 +88,9 @@ class _Linear(torch.autograd.Function): ...@@ -87,8 +88,9 @@ class _Linear(torch.autograd.Function):
is_input_fp8 = isinstance(inp, Float8Tensor) is_input_fp8 = isinstance(inp, Float8Tensor)
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = weight.shape[-1] out_features, in_features = weight.shape
assert inp.shape[-1] == in_features, "GEMM not possible" inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view(-1, in_features) inputmat = inp.view(-1, in_features)
if fp8: if fp8:
assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(inputmat)
...@@ -180,7 +182,7 @@ class _Linear(torch.autograd.Function): ...@@ -180,7 +182,7 @@ class _Linear(torch.autograd.Function):
out = ub_obj_projout.get_ubuf_output(1) out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size()) dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight_fp8.size(0) dim_size[1] = out_features
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj_projout.is_p2p_overlap(): if ub_obj_projout.is_p2p_overlap():
if ub_obj_projout.is_atomic_gemm(): if ub_obj_projout.is_atomic_gemm():
...@@ -200,7 +202,7 @@ class _Linear(torch.autograd.Function): ...@@ -200,7 +202,7 @@ class _Linear(torch.autograd.Function):
ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index]) ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
else: else:
dim_size = list(inputmat_total.size()) dim_size = list(inputmat_total.size())
dim_size[1] = weight_fp8.size(0) dim_size[1] = out_features
out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device) out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device)
_ = fp8_gemm( _ = fp8_gemm(
...@@ -260,7 +262,7 @@ class _Linear(torch.autograd.Function): ...@@ -260,7 +262,7 @@ class _Linear(torch.autograd.Function):
out = ub_obj_projout.get_ubuf_output(1) out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size()) dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group) dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group)
dim_size[1] = weight.size(0) dim_size[1] = out_features
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj_projout.is_p2p_overlap(): if ub_obj_projout.is_p2p_overlap():
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
...@@ -268,7 +270,7 @@ class _Linear(torch.autograd.Function): ...@@ -268,7 +270,7 @@ class _Linear(torch.autograd.Function):
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
else: else:
dim_size = list(inputmat_total.size()) dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0) dim_size[1] = out_features
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
_ = gemm( _ = gemm(
...@@ -334,7 +336,7 @@ class _Linear(torch.autograd.Function): ...@@ -334,7 +336,7 @@ class _Linear(torch.autograd.Function):
ctx.use_bias = use_bias ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape ctx.inp_shape = inp_shape
ctx.parallel_mode = parallel_mode ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.ub_overlap_ag = ub_overlap_ag ctx.ub_overlap_ag = ub_overlap_ag
...@@ -358,7 +360,7 @@ class _Linear(torch.autograd.Function): ...@@ -358,7 +360,7 @@ class _Linear(torch.autograd.Function):
out, _ = allreduce(out, tp_group) out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP # [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1]) return out.view(-1, *inp_shape[1:-1], out_features)
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
...@@ -941,7 +943,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -941,7 +943,7 @@ class Linear(TransformerEngineBaseModule):
) as inp: ) as inp:
# Get concatenated weight and bias tensors # Get concatenated weight and bias tensors
unfused_weights = [getattr(self, name) for name in self.weight_names] unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
if self.fp8: if self.fp8:
if len(unfused_weights) != 1: if len(unfused_weights) != 1:
...@@ -952,11 +954,9 @@ class Linear(TransformerEngineBaseModule): ...@@ -952,11 +954,9 @@ class Linear(TransformerEngineBaseModule):
unfused_weights = [w.dequantize() for w in unfused_weights] unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = _noop_cat(unfused_weights) weight_tensor = _noop_cat(unfused_weights)
if self.use_bias: if self.use_bias:
bias_tensor = _noop_cat( bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
[getattr(self, name) for name in self.bias_names],
)
else: else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused bias_tensor = self._fast_get_param(self.bias_names[0]) # Unused
# Initialize FP8 weights if needed # Initialize FP8 weights if needed
weight_fp8 = None weight_fp8 = None
...@@ -983,8 +983,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -983,8 +983,6 @@ class Linear(TransformerEngineBaseModule):
fsdp_group=self.fsdp_group, fsdp_group=self.fsdp_group,
) )
from ..cpu_offload import CPUOffloadEnabled
if torch.is_grad_enabled(): if torch.is_grad_enabled():
linear_fn = _Linear.apply linear_fn = _Linear.apply
args = [] args = []
...@@ -1002,7 +1000,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1002,7 +1000,7 @@ class Linear(TransformerEngineBaseModule):
self.fp8_calibration, self.fp8_calibration,
self.fp8_meta, self.fp8_meta,
self.fuse_wgrad_accumulation, self.fuse_wgrad_accumulation,
CPUOffloadEnabled, is_cpu_offload_enabled(),
self.tp_group, self.tp_group,
self.tp_size, self.tp_size,
self.sequence_parallel, self.sequence_parallel,
......
...@@ -11,7 +11,6 @@ import torch ...@@ -11,7 +11,6 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.nn import init from torch.nn import init
from .base import TransformerEngineBaseModule
from .. import cpp_extensions as tex from .. import cpp_extensions as tex
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..utils import cast_if_needed from ..utils import cast_if_needed
...@@ -146,6 +145,7 @@ class RMSNorm(torch.nn.Module): ...@@ -146,6 +145,7 @@ class RMSNorm(torch.nn.Module):
) )
) )
self.sequence_parallel = sequence_parallel self.sequence_parallel = sequence_parallel
self.activation_dtype: Optional[torch.dtype] = None
self.reset_parameters(defer_init=(device == "meta")) self.reset_parameters(defer_init=(device == "meta"))
...@@ -185,7 +185,19 @@ class RMSNorm(torch.nn.Module): ...@@ -185,7 +185,19 @@ class RMSNorm(torch.nn.Module):
"""RMSNorm FWD""" """RMSNorm FWD"""
# Set the activation type for AMP. # Set the activation type for AMP.
TransformerEngineBaseModule.set_activation_dtype(self, inp) # Note: This will soon be deprecated with
# https://github.com/NVIDIA/TransformerEngine/pull/1033
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
elif self.activation_dtype != inp.dtype:
dtype = inp.dtype
for name, param in self.named_parameters():
if param is not None:
assert dtype == param.dtype, (
"Data types for parameters must match when outside of autocasted region. "
f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
)
self.activation_dtype = dtype
if torch.is_grad_enabled(): if torch.is_grad_enabled():
fwd_fn = _RMSNorm.apply fwd_fn = _RMSNorm.apply
......
...@@ -751,7 +751,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -751,7 +751,7 @@ class TransformerLayer(torch.nn.Module):
return output return output
def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None): def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
if drop_path is None and bias.numel() != 0: if drop_path is None and bias is not None and bias.numel() != 0:
if self.bias_dropout_fusion: if self.bias_dropout_fusion:
if self.training: if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train bias_dropout_add_func = bias_dropout_add_fused_train
...@@ -763,7 +763,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -763,7 +763,7 @@ class TransformerLayer(torch.nn.Module):
with self.bias_dropout_add_exec_handler(): with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(hidden_state, bias, residual, self.hidden_dropout) output = bias_dropout_add_func(hidden_state, bias, residual, self.hidden_dropout)
else: else:
if bias.numel() != 0: if bias is not None and bias.numel() != 0:
hidden_state = hidden_state + bias hidden_state = hidden_state + bias
out = torch.nn.functional.dropout( out = torch.nn.functional.dropout(
hidden_state, p=self.hidden_dropout, training=self.training hidden_state, p=self.hidden_dropout, training=self.training
......
...@@ -218,8 +218,12 @@ def safely_set_viewless_tensor_data(tensor: torch.Tensor, new_data_tensor: torch ...@@ -218,8 +218,12 @@ def safely_set_viewless_tensor_data(tensor: torch.Tensor, new_data_tensor: torch
def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""Cast tensor to dtype""" """Cast tensor to dtype"""
if tensor is None:
return None
if tensor.dtype == dtype:
return tensor
with torch.enable_grad(): with torch.enable_grad():
return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype) return tensor.to(dtype=dtype)
def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool: def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment