# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Top level Transformer Engine PyTorch modules"""

import os
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Union, Optional, Callable, Tuple, Dict, Any, Mapping, List
from functools import partial
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn import init

import transformer_engine_extensions as tex
from .fp8 import (
    is_fp8_enabled,
    is_fp8_calibration,
    get_fp8_recipe,
    get_fp8_group,
    get_default_fp8_recipe,
    get_fp8_te_dtype,
    is_first_fp8_module,
    new_fp8_context_id,
    get_fp8_context_id,
    set_fp8_context_id,
    add_amax_to_global_buffer,
    copy_amax_from_global_buffer,
    global_amax_reduction,
    setup_amax_forward_global_reduce_func,
    amax_and_scale_update,
    get_global_fp8_buffer,
    set_global_fp8_buffer,
    set_amax_buffer_key_deletion,
    delete_key_from_amax_buffer,
    copy_forward_fp8_meta_tensors_for_recompute,
    get_old_fp8_meta_tensors_for_recompute,
    restore_fp8_meta_tensors,
    get_amax_reduce_handle_fwd,
)
from .jit import (
    bias_gelu_fused,
    bgrad_dgelu_fused,
    set_jit_fusion_options,
    warmup_jit_bias_gelu_all_dtypes,
)
from .utils import (
    divide,
    get_default_init_method,
    cast_if_needed,
    check_dim_for_fp8_forward_exec,
)
from .distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    initialize_affine_weight_gpu,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    gather_along_last_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
)
from .cpp_extensions import (
    fp8_gemm,
    gemm,
    fp8_cast_transpose_fused,
    fp8_cast_transpose_bgrad_fused,
    fp8_gelu,
    fp8_cast_transpose_bgrad_dgelu_fused,
    layernorm_fwd_fp8,
    layernorm_fwd_fp8_inf,
    layernorm_fwd_inf,
    cast_to_fp8,
    cast_from_fp8,
)
from .constants import GemmParallelModes, dist_group_type, TE_DType

_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
_amax_reduce_handle_bwd = None


def get_cublas_workspace_size_bytes() -> int:
    """Return 32 MiB if using Hopper, 4 MiB for all other architectures."""
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
        return 33_554_432
    return 4_194_304


def get_workspace() -> torch.Tensor:
    """Returns workspace for cublas."""
    global _cublas_workspace
    if _cublas_workspace is None:
        _cublas_workspace = torch.empty(
            get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
        )
    return _cublas_workspace


@contextmanager
def _prepare_backward(
    fp8: bool, fp8_meta: Dict[str, Any], tp_group: dist_group_type, tp_size: int, name: str = ""
) -> None:
    """Checks and prep for BWD."""
    if fp8:
        global _amax_reduce_handle_bwd
        if _amax_reduce_handle_bwd is not None:
            _amax_reduce_handle_bwd.wait()
            _amax_reduce_handle_bwd = None

        # Update amax and scale; skip all setup for global amax reduction
        if not fp8_meta["recipe"].reduce_amax:
            amax_and_scale_update(fp8_meta, False)
        else:
            # From previous iteration
            copy_amax_from_global_buffer(fp8_meta, forward=False)
            amax_and_scale_update(fp8_meta, False)
            set_amax_buffer_key_deletion(fp8_meta, forward=False)

            # Get new backward key.
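            # Each FP8 forward pass pushes its context id onto
            # `autocast_id_fwd_stack` (see `prepare_forward` below); the backward
            # pops the oldest entry (FIFO) so its amax entry in the global buffer
            # is matched with, and later deleted for, the forward pass that
            # produced it. Illustrative bookkeeping with made-up ids:
            #
            #     after two forwards : autocast_id_fwd_stack == [7, 8]
            #     this backward      : autocast_id_bwd = 7, stack becomes [8]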
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) add_amax_to_global_buffer(fp8_meta, forward=False) with torch.cuda.nvtx.range(name + " backward"): yield if fp8 and fp8_meta["recipe"].reduce_amax: if fp8_meta["first_module"]: _amax_reduce_handle_bwd = global_amax_reduction( fp8_meta, tp_group, tp_size, forward=False ) delete_key_from_amax_buffer(forward=False) def initialize_ub( shape: list, tp_size: int, use_fp8: bool = False, ub_cfgs: Optional[dict] = None ) -> None: """Initialize communicators for TP comm overlap using userbuffers.""" global _ub_communicators assert _ub_communicators is None, "UB communicators are already initialized." _ub_communicators = {} rank_id = torch.distributed.get_rank() # Increase the workspace by the number of maximum concurrent streams global _cublas_workspace _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS) # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe fp8_buf = [ "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad" ] # Default overlap methods for layers methods = { "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], "pipeline":["proj_fprop", "fc2_fprop"], "bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], } def get_method(name): for method, names in methods.items(): if name in names: return method raise KeyError(f"Given layer name {name} does not exist.") def add_ub( name: str, method: str, num_sm: int = 16, cga_size: int = 2, set_sm_margin: int = 0, num_splits: int = 4, aggregate: int = 0, ) -> None: dtype = torch.uint8 if (use_fp8 and name in fp8_buf) else torch.bfloat16 sample_buffer = torch.empty(shape, dtype=dtype, device='cuda') if method == 'ring_exchange': ub_obj = tex.UbufP2PCommOverlap( sample_buffer, # Sample userbuffer rank_id, # Rank id tp_size, # TP size aggregate, # Aggregate 2X GEMM chunks _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams ) else: ub_obj = tex.UbufCommOverlap( sample_buffer, # Sample userbuffer rank_id, # Rank id tp_size, # TP size num_sm, # Number of communication SMs cga_size, # CGA cluster size num_splits, # Number of communication splits set_sm_margin, # Set SM margin _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams ) _ub_communicators[name] = ub_obj for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]): if ub_cfgs is not None and name in ub_cfgs: ub_cfg = ub_cfgs[name] method = ub_cfg["method"] if "method" in ub_cfg else get_method(name) num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16 cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2 num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0 set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0 aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0 add_ub( name, method, num_sm, cga_size, set_sm_margin, num_splits, aggregate ) else: method = get_method(name) if method == "pipeline": add_ub(name, method) else: add_ub(name, method, num_splits=0) def get_ub(name: str): """Get userbuffer communicator corresponding to give key.""" global _ub_communicators assert _ub_communicators is not None, "UB manager is not initialized." assert name in _ub_communicators, f"UB for {name} is not registered." 
return _ub_communicators[name] class _NoopCat(torch.autograd.Function): """This class is a no-op replacement for `torch.cat`.""" @staticmethod def forward(ctx, full_param_buffer: torch.Tensor, *params_split: Tuple[torch.Tensor, ...], ) -> torch.Tensor: assert not full_param_buffer.requires_grad, "Buffers should not require gradient" assert ( full_param_buffer.shape[0] % len(params_split) == 0 ), "Dimensions not compatible for concatenation" param_temp = full_param_buffer.new() param_temp.set_(full_param_buffer.storage(), full_param_buffer.storage_offset(), full_param_buffer.size(), full_param_buffer.stride()) param_temp.requires_grad = True ctx.save_for_backward(full_param_buffer, *params_split) return param_temp @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: full_param_buffer, *params_split = ctx.saved_tensors split_size = full_param_buffer.shape[0] // len(params_split) grads = [] for i, _ in enumerate(params_split): grads.append(grad_output[i * split_size : (i+1) * split_size]) return None, *grads class TransformerEngineBaseModule(torch.nn.Module, ABC): """Base TE module.""" def __init__(self) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." self.fp8_initialized = False self.fp8 = False self.fp8_calibration = False self.fp8_meta = {} self.fp8_meta["fp8_group"] = None self.fp8_meta["recipe"] = get_default_fp8_recipe() self.fp8_meta_tensors_initialized = False self.tp_group = None self.tp_size = 1 self.sequence_parallel = False self.fp8_weight_shapes = [] self.fp8_meta["autocast_id_fwd_stack"] = [] self.fp8_meta["async_amax_reduction"] = bool( int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) ) def set_meta_tensor(self, fwd: bool) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" if self.fp8_meta_tensors_initialized: # Handle changed amax history size. curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0] need_len = self.fp8_meta["recipe"].amax_history_len if need_len < curr_len: self.fp8_meta[fp8_meta_tensor_key].amax_history = ( self.fp8_meta[fp8_meta_tensor_key] .amax_history[: self.fp8_meta["recipe"].amax_history_len].clone() ) elif need_len > curr_len: extra_rows = need_len - curr_len self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad( self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows) ) return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd num_fp8_tensors = ( self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 ) self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones( num_fp8_tensors, dtype=torch.float32, device="cuda" ) self.fp8_meta[fp8_meta_tensor_key].scale_inv = torch.ones( num_fp8_tensors, dtype=torch.float32, device="cuda" ) self.fp8_meta[fp8_meta_tensor_key].amax_history = torch.zeros( self.fp8_meta["recipe"].amax_history_len, num_fp8_tensors, dtype=torch.float32, device="cuda", ) # Needed for calculation of scale inverses to # preserve scale_inv when caching FP8 weights if fwd: # [True, False, True]: -> [input, weight, output] self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor( [True, False, True] * self.fp8_meta["num_gemms"] ).cuda() else: # [True, True]: -> [grad_output, grad_input] self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor( [True, True] * self.fp8_meta["num_gemms"] ).cuda() def init_fp8_meta_tensors(self) -> None: """Init scales and amaxes.""" self.set_meta_tensor(True) self.set_meta_tensor(False) self.fp8_meta_tensors_initialized = True def get_extra_state(self) -> torch.Tensor: """Save before checkpointing.""" state = None if self.fp8 or self.fp8_calibration: state = {} state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history state["global_fp8_buffer"] = get_global_fp8_buffer() # Store other pickelable values. extra = {} for k, v in self.fp8_meta.items(): if isinstance(v, (bool, int, float, str)): extra[k] = v state["extra_fp8_variables"] = extra state_serialized = pickle.dumps(state) state_tensor = torch.tensor(np.frombuffer(state_serialized, dtype=np.uint8)) return state_tensor def set_extra_state(self, state: torch.Tensor) -> None: """Load previous state.""" if state is None: return # Maintain backward compatibility with v0.2.0 and older. if isinstance(state, list): warnings.warn( "This checkpoint format is deprecated and will be" "removed in a future release of Transformer Engine" ) # Retrieve checkpointed items. scale_fwd = state[0] amax_history_fwd = state[1] scale_bwd = state[2] amax_history_bwd = state[3] self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0] self.fp8_meta["num_gemms"] = ( amax_history_fwd.shape[1] // 2 ) # Two FWD tensors per GEMM # Initialize before loading self.init_fp8_meta_tensors() self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd) self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd) self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd) self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd) # Restore global FP8 buffer state. set_global_fp8_buffer(state[4]) self.fp8_meta["update_amax_and_scale_fwd"] = state[5] self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6] self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7] self.fp8_meta["autocast_id_fwd"] = state[8] self.fp8_meta["autocast_id_bwd"] = state[9] return if isinstance(state, torch.Tensor): state = pickle.loads(state.detach().cpu().numpy().tobytes()) if state is None: return # Restore global FP8 buffer states. 
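        # The fp8 metadata round-trips through the checkpoint as a pickled uint8
        # tensor: `get_extra_state` does
        # `torch.tensor(np.frombuffer(pickle.dumps(state), dtype=np.uint8))` and this
        # method reverses it with `pickle.loads(tensor.detach().cpu().numpy().tobytes())`.
        # A minimal sketch of that round-trip with a toy dict (illustrative only):
        #
        #     blob = pickle.dumps({"amax_history_len": 1024})
        #     t = torch.tensor(np.frombuffer(blob, dtype=np.uint8))
        #     assert pickle.loads(t.numpy().tobytes())["amax_history_len"] == 1024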
set_global_fp8_buffer(state["global_fp8_buffer"]) # Load extra items. self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta: del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Initialize before loading. self.init_fp8_meta_tensors() self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"]) self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"]) self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"]) self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"]) # Backwards compatibility: compute scale inv if it wasn't saved in the extra state. if "scale_inv_fwd" not in state or "scale_inv_bwd" not in state: assert ( "scale_inv_fwd" not in state and "scale_inv_bwd" not in state ), "Invalid state, began saving scale_inv_fwd and scale_inv_bwd at the same time" self.fp8_meta["scaling_fwd"].scale_inv.copy_(1.0/state["scale_fwd"]) self.fp8_meta["scaling_bwd"].scale_inv.copy_(1.0/state["scale_bwd"]) else: self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"]) self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"]) def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" # Native AMP (`torch.autocast`) gets highest priority if torch.is_autocast_enabled(): self.activation_dtype = torch.get_autocast_gpu_dtype() return # All checks after this have already been performed once, thus skip # We assume that user doesn't change input types across iterations if hasattr(self, "activation_dtype"): return assert all( ( (inp.dtype == param.dtype) if param is not None else True for param in self.parameters() ) ), ( "Data type for activations and weights must " "match when outside of autocasted region" ) assert all( ( (inp.dtype == buf.dtype) if buf is not None else True for buf in self.buffers() ) ), ( "Data type for activations and buffers must " "match when outside of autocasted region" ) self.activation_dtype = inp.dtype def set_fp8_weights(self) -> None: """Initializes FP8 weights for the module as class attributes. These are not parameters or buffers since we do not want functions such as `.to(dtype)` or `.to(device)` to effect them. These also do not need to be checkpointed. During `init` phase of the module, the attribute `fp8_weight_shapes` must be populated with the tensor shapes for FP8 weights. This function will iterate over those shapes and initialize respective attributed named `weight1_fp8`, `weight2_fp8`, ... """ if not self.fp8: return for i, shape in enumerate(self.fp8_weight_shapes, start=1): weight_cast_attr = f"weight{i}_fp8" weight_transpose_attr = f"weight{i}_t_fp8" if ( hasattr(self, weight_cast_attr) and getattr(self, weight_cast_attr).shape == shape ): return setattr( self, weight_cast_attr, torch.empty( shape, device=torch.cuda.current_device(), dtype=torch.uint8, ), ) setattr( self, weight_transpose_attr, torch.empty( shape[1], shape[0], device=torch.cuda.current_device(), dtype=torch.uint8, ), ) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """Set TP group.""" self.tp_group = tp_group self.tp_group_initialized = True # This routine is shared across FP8 and FP8_calibration paths so should not actually # assume FP8 execution. 
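    # `is_fp8_enabled()` / `is_fp8_calibration()` used below reflect the enclosing
    # `fp8_autocast` context from `transformer_engine.pytorch`. A typical call site
    # looks roughly like the following sketch (`model` and `inp` are placeholders,
    # and the recipe argument is optional):
    #
    #     import transformer_engine.pytorch as te
    #     with te.fp8_autocast(enabled=True):
    #         out = model(inp)          # fp8_init() runs on the first forward
    #     out.sum().backward()          # backward outside the autocast region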
def fp8_init(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" if is_fp8_enabled() or is_fp8_calibration(): # FP8 init has already been run and recipe is the same, don't do anything. if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]: return # Set FP8, recipe, and other FP8 metadata self.fp8 = is_fp8_enabled() self.fp8_calibration = is_fp8_calibration() self.fp8_meta["recipe"] = get_fp8_recipe() self.fp8_meta["num_gemms"] = num_gemms self.fp8_meta["fp8_group"] = get_fp8_group() # Set FP8_MAX per tensor according to recipe self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes self.init_fp8_meta_tensors() self.fp8_initialized = True else: # If fp8 isn't enabled, turn off and return. self.fp8_initialized = False return @contextmanager def prepare_forward( self, inp: torch.Tensor, is_first_microbatch: Union[bool, None], num_gemms: int = 1, ) -> None: """Checks and prep for FWD. The context manager is needed because there isn't a way for a module to know if it's the last FP8 module in the forward autocast. It is useful to setup the forward aggregated amax reduction for every module just in case. The autocast exit will pick up the most recent one. """ # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) else: assert inp.is_cuda, "TransformerEngine needs CUDA." if self.tp_size > 1: assert self.tp_group_initialized, "TP group not initialized." self.set_activation_dtype(inp) self.fp8_init(num_gemms=num_gemms) self.set_fp8_weights() update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch if self.fp8 and self.sequence_parallel: assert self.fp8_meta["recipe"].reduce_amax, \ "Amax reduction across tensor parallel group is " \ "necessary when using sequence parallelism with FP8." # Previous iteration was grad_enabled if self.fp8_meta.get("update_amax_and_scale_fwd", False): if self.fp8_meta["recipe"].reduce_amax: copy_amax_from_global_buffer(self.fp8_meta, forward=True) amax_and_scale_update( self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv ) set_amax_buffer_key_deletion(self.fp8_meta, forward=True) else: amax_and_scale_update( self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv ) if self.fp8 and self.training: # Setup for amax reduction if self.fp8_meta["recipe"].reduce_amax: self.fp8_meta["first_module"] = is_first_fp8_module() if self.fp8_meta["first_module"]: # Wait for the prior AMAX reduction to finish amax_reduce_handle_fwd = get_amax_reduce_handle_fwd() if amax_reduce_handle_fwd is not None: amax_reduce_handle_fwd.wait() self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id() set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) else: self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id() self.fp8_meta["autocast_id_fwd_stack"].append( self.fp8_meta["autocast_id_fwd"] ) add_amax_to_global_buffer(self.fp8_meta, forward=True) self.fp8_meta["update_amax_and_scale_fwd"] = True else: self.fp8_meta["update_amax_and_scale_fwd"] = False # Activation recomputation is used and this is the first forward phase. 
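        # In the first phase the just-updated scales/amaxes are stashed so that the
        # checkpoint-recompute forward (run later, under
        # `in_fp8_activation_recompute_phase()`) sees identical FP8 metadata; they
        # are restored right after the `yield` below, keeping the recomputed
        # activations numerically consistent with the original pass.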
        if (
            self.fp8
            and self.training
            and is_fp8_activation_recompute_enabled()
            and not in_fp8_activation_recompute_phase()
        ):
            copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
            yield inp.contiguous()

        if self.fp8 and in_fp8_activation_recompute_phase():
            restore_fp8_meta_tensors(self.fp8_meta)
            return

        if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
            set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
            reduce_func = partial(
                global_amax_reduction,
                self.fp8_meta,
                self.tp_group,
                self.tp_size,
                forward=True,
            )
            setup_amax_forward_global_reduce_func(reduce_func)

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.
        """
        if self.tp_size == 1:
            return
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
        ctx, grad_output: torch.Tensor, row_parallel_mode: bool
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.
        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: gathered `grad_output` in higher precision.
            R2: gathered `grad_output` in FP8.
            R3: R2 transposed.
            R4: bias gradient on R1.
        """
        grad_output = grad_output.contiguous()
        grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # No-FP8 case: bgrad is fused with wgrad for this case.
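        # (In the non-FP8 path the bias gradient is computed together with the wgrad
        # GEMM, so no separate `grad_bias` is produced here; see the `use_bias=ctx.use_bias`
        # wgrad `gemm(...)` calls in the backward implementations further below.)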
if not ctx.fp8: if gather_grad_output: if not ctx.ub_split_ag: grad_output_mat, _ = gather_along_first_dim( grad_output_mat, ctx.tp_group ) else: ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True) grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1) return grad_output_mat, None, None, None fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) # FP8 case with non-FP8 wgrad if ( gather_grad_output and ctx.fp8_meta["recipe"].override_linear_precision.wgrad ): assert ( not ctx.ub_split_ag ), "override_linear_precision.wgrad not supported with ub_split_ag" grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather elif gather_grad_output: if ctx.use_bias: grad_bias = grad_output_mat.sum(dim=0) else: grad_bias = None if ctx.ub_split_ag: grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0) else: grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8) cast_to_fp8( grad_output_mat, ctx.fp8_meta["scaling_bwd"], tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, out=grad_output_c, ) if not ctx.ub_split_ag: grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) else: grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1) grad_output_t = None return grad_output_mat, grad_output_c, grad_output_t, grad_bias # FP8 case without gather: cast, transpose, bgrad fused if ctx.use_bias: grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused( grad_output_mat, ctx.fp8_meta["scaling_bwd"], tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ) else: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: grad_output_c, grad_output_t = fp8_cast_transpose_fused( grad_output_mat, ctx.fp8_meta["scaling_bwd"], tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ) else: grad_output_t = None grad_output_c = cast_to_fp8( grad_output_mat, ctx.fp8_meta["scaling_bwd"], tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ) grad_bias = None return grad_output_mat, grad_output_c, grad_output_t, grad_bias def noop_cat(self, buffer_name: str, pnames: List[str]) -> torch.Tensor: """No-op replacement of `torch.cat`. The buffer and split parameters must occupy the same memory region. If this is not the case, then the split parameters are concatenated and the buffer is overwritten. The parameters' memory is then re-assigned to point to the buffer to avoid subsequent concatenations. """ assert hasattr(self, buffer_name), f"No buffer named {buffer_name}" full_param_buffer = getattr(self, buffer_name) split_size = full_param_buffer.shape[0] // len(pnames) params = [getattr(self, name) for name in pnames] for i, p in enumerate(params): if p.data.data_ptr() != full_param_buffer[i*split_size : (i+1)*split_size].data_ptr(): with torch.no_grad(): setattr(self, buffer_name, torch.cat(params)) for j, pname in enumerate(pnames): full_param_buffer = getattr(self, buffer_name) setattr(self, pname, Parameter(full_param_buffer[j*split_size : (j+1)*split_size])) break return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames]) @abstractmethod def forward(self): """Needs override.""" class _LayerNormLinear(torch.autograd.Function): """LayerNormLinear semi-top level module Calls custom cuda extensions. 
""" @staticmethod def forward( ctx, inp: torch.Tensor, ln_weight: torch.Tensor, ln_bias: torch.Tensor, weight: torch.Tensor, weight_fp8: Union[torch.Tensor, None], weight_t_fp8: Union[torch.Tensor, None], bias: torch.Tensor, use_bias: bool, eps: float, is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, tp_group: Union[dist_group_type, None], tp_size: int, sequence_parallel: bool, tensor_parallel: bool, activation_dtype: torch.dtype, parallel_mode: Union[str, None], return_layernorm_output: bool, is_grad_enabled: bool, fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_split_ag: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() assert inp.shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) assert ( not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight) ), "Input and weight dimensions are not compatible for FP8 execution." update_fp8_weights = is_first_microbatch is None or is_first_microbatch # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) ln_weight = cast_if_needed(ln_weight, activation_dtype) ln_bias = cast_if_needed(ln_bias, activation_dtype) # If residual connection is after LN, we need `ln_out` # tensor in higher precision, this comes at the cost # of an extra fp8 cast. if ub_split_ag: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: ub_split_ag = False if ub_split_ag: dim_size = list(inputmat.size()) dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub("qkv_fprop") ln_out = ub_obj_lnout.get_ubuf_output(0) if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if not return_layernorm_output: if is_grad_enabled: if not ub_split_ag: ln_out = torch.empty_like(inputmat, dtype=torch.uint8) _, mu, rsigma = layernorm_fwd_fp8( inputmat, ln_weight, ln_bias, eps, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, fwd_ln_sm_margin, zero_centered_gamma, ln_out = ln_out ) else: mu = rsigma = None ln_out = layernorm_fwd_fp8_inf( inputmat, ln_weight, ln_bias, eps, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, zero_centered_gamma, ) else: if is_grad_enabled: ln_out_return, mu, rsigma = tex.layernorm_fwd( inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma ) else: ln_out_return, mu, rsigma = layernorm_fwd_inf( inputmat, ln_weight, ln_bias, eps, zero_centered_gamma ), None, None ln_out = cast_to_fp8( ln_out_return, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, ) else: if is_grad_enabled: if ub_split_ag: _, mu, rsigma = tex.layernorm_fwd_noalloc( inputmat, ln_weight, ln_bias, ln_out, eps, fwd_ln_sm_margin, zero_centered_gamma ) else: ln_out, mu, rsigma = tex.layernorm_fwd( inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma ) else: ln_out, mu, rsigma = layernorm_fwd_inf( inputmat, ln_weight, ln_bias, eps, zero_centered_gamma ), None, None ln_out_return = ln_out # Column Parallel Linear if ub_split_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) elif parallel_mode == "column" and sequence_parallel: ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) else: ln_out_total = ln_out if fp8: bias_dtype = ( torch.bfloat16 if 
activation_dtype == torch.float32 else activation_dtype ) bias = cast_if_needed(bias, bias_dtype) if use_bias else bias if update_fp8_weights: if is_grad_enabled: fp8_cast_transpose_fused( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, cast_out=weight_fp8, transpose_out=weight_t_fp8, ) else: weight_t_fp8 = None weight_fp8 = cast_to_fp8( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward) out = fp8_gemm( weight_fp8, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ln_out_total, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, activation_dtype, get_workspace(), bias=bias, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, ub=ub_obj_lnout if ub_split_ag else None, extra_output_tensor=ln_out if ub_split_ag else None, ) else: # Cast for native AMP weight = cast_if_needed(weight, activation_dtype) bias = cast_if_needed(bias, activation_dtype) if use_bias else bias if fp8_calibration: # amax of input fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \ torch.amax(ln_out_total).float() # amax of weight fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ torch.amax(weight).float() out, _, _ = gemm( weight, ln_out_total, activation_dtype, get_workspace(), bias=bias, use_bias=use_bias, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, ub=ub_obj_lnout if ub_split_ag else None, extra_output_tensor=ln_out if ub_split_ag else None, ) if is_grad_enabled: ctx.save_for_backward( inputmat, ln_weight, mu, rsigma, weight, weight_t_fp8, ln_out, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ) ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp.shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group ctx.tp_size = tp_size ctx.return_layernorm_output = return_layernorm_output ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.requires_dgrad = inp.requires_grad # Row Parallel Linear if parallel_mode == "row" and sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif parallel_mode == "row" and tensor_parallel: out, _ = allreduce(out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP out = out.view(-1, *inp.shape[1:-1], out.shape[-1]) if return_layernorm_output: return out, ln_out_return.view_as(inp) return out @staticmethod def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] 
) -> Tuple[Union[torch.Tensor, None], ...]: with _prepare_backward( ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear" ): ( inputmat, ln_weight, mu, rsigma, weight, weight_t_fp8, ln_out, fwd_scale_inverses, ) = ctx.saved_tensors if ctx.ub_bulk_dgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1: ctx.ub_bulk_dgrad = False if ctx.ub_bulk_dgrad: dim_size = list(ln_out.size()) dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub("qkv_dgrad") ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) ( grad_output, grad_output_c, grad_output_t, grad_bias, ) = TransformerEngineBaseModule.grad_output_preprocess( ctx, grad_outputs[0], ctx.parallel_mode == "row" ) if ctx.ub_bulk_wgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1: ctx.ub_bulk_wgrad = False # Column Parallel Linear # Overlap input AG with dgrad if (not ctx.ub_bulk_dgrad) and ctx.parallel_mode == "column" and ctx.sequence_parallel: ln_out_total, handle = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True ) else: ln_out_total = ln_out if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation dgrad_size = list(grad_output.size()) dgrad_size[1] = weight.size(1) if ctx.ub_bulk_wgrad: # allocate dgrad output ub_obj_dgrad = get_ub("qkv_wgrad") dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output else: dgrad = torch.empty (dgrad_size, dtype=ctx.activation_dtype, device=weight.device) if ctx.fp8: fp8_dtype_forward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=True ) fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) # DGRAD: Evaluated unconditionally to feed into Linear backward _ = fp8_gemm( weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, grad_output_c, ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), out=dgrad, use_split_accumulator=_2X_ACC_DGRAD, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None ) else: # DGRAD: Evaluated unconditionally to feed into Linear backward _, _, _ = gemm( weight, grad_output, ctx.activation_dtype, get_workspace(), out=dgrad, layout="NN", grad=True, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None ) if ctx.ub_bulk_dgrad: ln_out_total = ub_obj_lnout.get_ubuf_output(1) # Overlap dgrad-RS/AR with wgrad if ctx.parallel_mode == "column" and ctx.sequence_parallel: if not ctx.ub_bulk_dgrad: handle.wait() if not ctx.ub_bulk_wgrad: dgrad, handle = reduce_scatter_along_first_dim( dgrad, ctx.tp_group, async_op=True ) elif ctx.parallel_mode == "column" and ctx.tensor_parallel: dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) if weight.requires_grad: if ctx.fp8: # WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) wgrad = fp8_gemm( ln_out_total_t, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, 
use_split_accumulator=_2X_ACC_WGRAD, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) else: ln_out_total_c = cast_from_fp8( ln_out_total, ctx.fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, TE_DType[ctx.activation_dtype], ) wgrad, _, _ = gemm( ln_out_total_c, grad_output, ctx.activation_dtype, get_workspace(), layout="NT", grad=True, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) else: # WGRAD wgrad, grad_bias, _ = gemm( ln_out_total, grad_output, ctx.activation_dtype, get_workspace(), layout="NT", grad=True, use_bias=ctx.use_bias, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) if ctx.ub_bulk_wgrad: dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output # Column Parallel Linear elif ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: handle.wait() # LayerNorm gradient d_ln_out = dgrad.view(inputmat.shape) # Residual gradient if ctx.return_layernorm_output: d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) dxmat, dgamma, dbeta = tex.layernorm_bwd( d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ) if not ctx.use_bias: grad_bias = None return ( dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, dbeta, wgrad if weight.requires_grad else None, None, None, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) class LayerNormLinear(TransformerEngineBaseModule): r""" Applies layer normalization followed by linear transformation to the incoming data. Parameters ---------- in_features : int size of each input sample. out_features : int size of each output sample. eps : float, default = 1e-5 a value added to the denominator of layer normalization for numerical stability. bias : bool, default = `True` if set to `False`, the layer will not learn an additive bias. init_method : Callable, default = `None` used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. return_layernorm_output : bool, default = `False` if set to `True`, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm. parameters_split : Tuple[str, ...], default = None if a tuple of strings is provided, the weight and bias parameters of the module are exposed as `N` separate `torch.nn.parameter.Parameter`s each, split along the first dimension, where `N` is the length of the argument and the strings contained are the names of the split parameters. zero_centered_gamma : bool, default = 'False' if set to 'True', gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to .. math:: y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta Parallelism parameters ---------------------- sequence_parallel : bool, default = `False` if set to `True`, uses sequence parallelism. 
tp_group : ProcessGroup, default = `None` tensor parallel process group. tp_size : int, default = 1 used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. parallel_mode : {None, 'Column', 'Row'}, default = `None` used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. skip_weight_param_allocation: bool, default = `False` if set to `True`, weight parameter is not allocated and must be passed as a keyword argument `weight` during the forward pass. Optimization parameters ----------------------- fuse_wgrad_accumulation : bool, default = 'False' if set to `True`, enables fusing of creation and accumulation of the weight gradient. return_bias : bool, default = `False` when set to `True`, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation :math:`y = xA^T`. This is useful when the bias addition can be fused to subsequent operations. params_dtype : torch.dtype, default = `torch.float32` it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. """ def __init__( self, in_features: int, out_features: int, eps: float = 1e-5, sequence_parallel: bool = False, fuse_wgrad_accumulation: bool = False, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, init_method: Optional[Callable] = None, bias: bool = True, return_bias: bool = False, params_dtype: torch.dtype = torch.float32, parallel_mode: Optional[str] = None, return_layernorm_output: bool = False, skip_weight_param_allocation: bool = False, parameters_split: Optional[Tuple[str, ...]] = None, zero_centered_gamma: bool = False, ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_split_ag: bool = False, ) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.fuse_wgrad_accumulation = fuse_wgrad_accumulation self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias self.return_layernorm_output = return_layernorm_output self.parameters_split = parameters_split self.zero_centered_gamma = zero_centered_gamma self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_ag = ub_split_ag if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag: assert ( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." 
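        # The `ub_*` overlap flags additionally assume that `initialize_ub(...)`
        # (defined earlier in this file) has already been called with the expected
        # communication buffer shape, since the forward/backward paths fetch the
        # registered communicators via `get_ub("qkv_fprop")`, `get_ub("qkv_dgrad")`,
        # and `get_ub("qkv_wgrad")`. A rough construction sketch (placeholder sizes):
        #
        #     initialize_ub([seq_len * batch_size, hidden_size], tp_world_size, use_fp8=True)
        #     layer = LayerNormLinear(hidden_size, 3 * hidden_size,
        #                             ub_bulk_wgrad=True, ub_bulk_dgrad=True, ub_split_ag=True)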
if tp_group is None: self.tp_size = tp_size if tp_size == 1: self.set_tensor_parallel_group(tp_group) else: self.tp_size = get_distributed_world_size(tp_group) self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes ), f"parallel_mode {parallel_mode} not supported" if self.parallel_mode == "column": self.out_features = divide(self.out_features, self.tp_size) elif self.parallel_mode == "row": self.in_features = divide(self.in_features, self.tp_size) if init_method is None: init_method = get_default_init_method() self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.eps = eps self.layer_norm_weight = Parameter( torch.empty( in_features, device=torch.cuda.current_device(), dtype=params_dtype, ) ) self.layer_norm_bias = Parameter( torch.empty( in_features, device=torch.cuda.current_device(), dtype=params_dtype, ) ) setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) self.reset_layer_norm_parameters() if not skip_weight_param_allocation: self.register_buffer("weight_tensor", torch.empty( self.out_features, self.in_features, device=torch.cuda.current_device(), dtype=params_dtype), persistent=False) initialize_affine_weight_gpu( self.weight_tensor, init_method, get_rng_state_tracker, partition_dim=1 if self.parallel_mode == "row" else 0, stride=1, ) if self.use_bias: self.register_buffer("bias_tensor", torch.empty( self.out_features, device=torch.cuda.current_device(), dtype=params_dtype), persistent=False) else: self.register_buffer( "bias_tensor", torch.Tensor().type(params_dtype), persistent=False ) with torch.no_grad(): self.bias_tensor.zero_() if parameters_split is None: parameters_split = ("",) assert ( self.out_features % len(parameters_split) == 0 ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" split_size = self.out_features // len(parameters_split) self.weight_names = [] self.bias_names = [] for i, pname in enumerate(parameters_split): wname = pname + "weight" bname = pname + "bias" self.register_parameter( wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size]) ) set_tensor_model_parallel_attributes( tensor=getattr(self, wname), is_parallel=True, dim=1 if parallel_mode == "row" else 0, stride=1, ) if self.use_bias: self.register_parameter( bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) ) else: self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False) if parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) self.weight_names.append(wname) self.bias_names.append(bname) self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.parallel_mode == "row" and self.apply_bias: self.gemm_bias_unfused_add = True else: self.gemm_bias_unfused_add = False # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. These envvars can be used to prevent the LN # kernels from using all SMs in the device. This is useful for cases such as # communication overlap with LN. 
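        # For example, exporting NVTE_FWD_LAYERNORM_SM_MARGIN=8 before launching the
        # job leaves 8 SMs free for a concurrent communication kernel during the
        # forward LayerNorm; the backward margin is controlled independently below.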
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) def reset_layer_norm_parameters(self) -> None: """Init LN params""" if not self.zero_centered_gamma: init.ones_(self.layer_norm_weight) else: init.zeros_(self.layer_norm_weight) init.zeros_(self.layer_norm_bias) def forward( self, inp: torch.Tensor, weight: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a linear transformation. Parameters ---------- inp : torch.Tensor Input tensor. weight : torch.Tensor, default = None An optional weight tensor for the module. This argument is compulsory if module is initialized with `skip_weight_param_allocation=True` bias : torch.Tensor, default = None An optional bias tensor for the module. This argument is compulsory if module is initialized with `skip_weight_param_allocation=True` and one of `use_bias` or `return_bias` is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations: * during FP8 training, it allows caching of the FP8 versions of the weights * it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced) """ with self.prepare_forward(inp, is_first_microbatch) as inp: bias_tensor = ( bias if bias is not None else self.bias if self.parameters_split is None else self.bias_tensor if not torch.is_grad_enabled() else self.noop_cat("bias_tensor", self.bias_names) ) weight_tensor = ( weight if weight is not None else self.weight if self.parameters_split is None else self.weight_tensor if not torch.is_grad_enabled() else self.noop_cat("weight_tensor", self.weight_names) ) if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply args = [] else: fwd_fn = _LayerNormLinear.forward args = [None] args += ( inp, self.layer_norm_weight, self.layer_norm_bias, weight_tensor, self.weight1_fp8 if self.fp8 else None, self.weight1_t_fp8 if self.fp8 else None, bias_tensor, self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, self.fp8, self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, self.tp_group, self.tp_size, self.sequence_parallel, self.tp_size > 1, self.activation_dtype, self.parallel_mode, self.return_layernorm_output, torch.is_grad_enabled(), self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_split_ag, ) out = fwd_fn(*args) if self.return_layernorm_output: out, ln_out = out if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) if self.return_bias: if self.return_layernorm_output: return out, cast_if_needed(bias_tensor, self.activation_dtype), ln_out return out, cast_if_needed(bias_tensor, self.activation_dtype) if self.return_layernorm_output: return out, ln_out return out class _Linear(torch.autograd.Function): """Linear semi-top level module Calls custom cuda extensions. 
""" @staticmethod def forward( ctx, weight: torch.Tensor, weight_fp8: Union[torch.Tensor, None], weight_t_fp8: Union[torch.Tensor, None], inp: torch.Tensor, bias: torch.Tensor, use_bias: bool, is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, tp_group: Union[dist_group_type, None], tp_size: int, sequence_parallel: bool, tensor_parallel: bool, activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, ub_split_rs: bool, ub_split_ag: bool, ) -> torch.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] assert inp.shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) assert ( not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight) ), "Input and weight dimensions are not compatible for FP8 execution." update_fp8_weights = is_first_microbatch is None or is_first_microbatch if ub_split_rs: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1: ub_split_rs = False # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) inputmat_no_fp8 = inputmat if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if not fp8_meta["recipe"].override_linear_precision.wgrad: if is_grad_enabled: inputmat, inputmat_t = fp8_cast_transpose_fused( inputmat, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, ) else: inputmat = cast_to_fp8( inputmat, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, ) else: inputmat, inputmat_t = cast_to_fp8( inputmat, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, ), None # Column Parallel Linear if parallel_mode == "column" and sequence_parallel: inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) else: inputmat_total = inputmat if fp8: bias_dtype = ( torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype ) bias = cast_if_needed(bias, bias_dtype) if use_bias else bias if update_fp8_weights: if is_grad_enabled: fp8_cast_transpose_fused( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, cast_out=weight_fp8, transpose_out=weight_t_fp8, ) else: weight_t_fp8 = None weight_fp8 = cast_to_fp8( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ) if ub_split_rs: ub_obj_projout = get_ub("proj_fprop") out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) else: dim_size = list(inputmat_total.size()) dim_size[1] = weight.size(0) out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) _ = fp8_gemm( weight_fp8, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, inputmat, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, activation_dtype, get_workspace(), bias=bias, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, out=out, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, ub=ub_obj_projout if ub_split_rs else None, extra_output_tensor=rs_out if ub_split_rs else None, ) else: # Cast for native AMP weight = cast_if_needed(weight, activation_dtype) bias = cast_if_needed(bias, activation_dtype) if use_bias else bias if fp8_calibration: # amax of input 
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \ torch.amax(inputmat_total).float() # amax of weight fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ torch.amax(weight).float() if ub_split_rs: ub_obj_projout = get_ub("proj_fprop") out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) else: dim_size = list(inputmat_total.size()) dim_size[1] = weight.size(0) out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) _, _, _ = gemm( weight, inputmat_total, activation_dtype, get_workspace(), bias=bias, use_bias=use_bias, out=out, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, ub=ub_obj_projout if ub_split_rs else None, extra_output_tensor=rs_out if ub_split_rs else None, ) if is_grad_enabled: fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad ctx.save_for_backward( inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None, inputmat_t if weight.requires_grad and fp8_wgrad else None, weight, weight_t_fp8 if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ) ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp.shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group ctx.ub_split_ag = ub_split_ag ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad # Row Parallel Linear if ub_split_rs: out = rs_out elif parallel_mode == "row" and sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif parallel_mode == "row" and tensor_parallel: out, _ = allreduce(out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @staticmethod def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: with _prepare_backward( ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear" ): ( inputmat, inputmat_t, weight, weight_t_fp8, fwd_scale_inverses, ) = ctx.saved_tensors if ctx.ub_split_ag: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1: ctx.ub_split_ag = False if ctx.ub_split_ag: dim_size = list(grad_output.size()) dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub("proj_dgrad") ( grad_output, grad_output_c, grad_output_t, grad_bias, ) = TransformerEngineBaseModule.grad_output_preprocess( ctx, grad_output, ctx.parallel_mode == "row" ) # Column Parallel Linear # Overlap input AG with dgrad if ctx.parallel_mode == "column" and ctx.sequence_parallel: if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: inputmat_t_total, handle = gather_along_last_dim( inputmat_t, ctx.tp_group, async_op=ctx.requires_dgrad ) else: inputmat_total, handle = gather_along_first_dim( inputmat, ctx.tp_group, async_op=ctx.requires_dgrad ) else: inputmat_t_total = inputmat_t inputmat_total = inputmat if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation if ctx.fp8: 
fp8_dtype_forward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=True ) fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) if ctx.requires_dgrad: if ctx.fp8: dgrad = fp8_gemm( weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, grad_output_c, ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, ) else: dgrad, _, _ = gemm( weight, grad_output, ctx.activation_dtype, get_workspace(), layout="NN", grad=True, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, ) # Overlap dgrad-RS/AR with wgrad if ctx.parallel_mode == "column" and ctx.sequence_parallel: handle.wait() dgrad, handle = reduce_scatter_along_first_dim( dgrad, ctx.tp_group, async_op=True ) elif ctx.parallel_mode == "column" and ctx.tensor_parallel: dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) if weight.requires_grad: if ctx.fp8: # WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: if ctx.ub_split_ag: grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) wgrad = fp8_gemm( inputmat_t_total, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, ) else: wgrad, _, _ = gemm( inputmat_total, grad_output, ctx.activation_dtype, get_workspace(), layout="NT", grad=True, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) else: # WGRAD wgrad, grad_bias, _ = gemm( inputmat_total, grad_output, ctx.activation_dtype, get_workspace(), layout="NT", grad=True, use_bias=ctx.use_bias, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) # Column Parallel Linear if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: handle.wait() if not ctx.use_bias: grad_bias = None return ( wgrad if weight.requires_grad else None, None, None, dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) class Linear(TransformerEngineBaseModule): """ Applies a linear transformation to the incoming data :math:`y = xA^T + b` On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`. Parameters ---------- in_features : int size of each input sample. out_features : int size of each output sample. bias : bool, default = `True` if set to `False`, the layer will not learn an additive bias. init_method : Callable, default = `None` used for initializing weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. 
parameters_split : Tuple[str, ...], default = None if a tuple of strings is provided, the weight and bias parameters of the module are exposed as `N` separate `torch.nn.parameter.Parameter`s each, split along the first dimension, where `N` is the length of the argument and the strings contained are the names of the split parameters. Parallelism parameters ---------------------- sequence_parallel : bool, default = `False` if set to `True`, uses sequence parallelism. tp_group : ProcessGroup, default = `None` tensor parallel process group. tp_size : int, default = 1 used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. parallel_mode : {None, 'Column', 'Row'}, default = `None` used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. skip_weight_param_allocation: bool, default = `False` if set to `True`, weight parameter is not allocated and must be passed as a keyword argument `weight` during the forward pass. Optimization parameters ----------------------- fuse_wgrad_accumulation : bool, default = 'False' if set to `True`, enables fusing of creation and accumulation of the weight gradient. When enabled, it is assumed that the weights have an additional `main_grad` attribute (used instead of the regular `grad`) which is a pre-allocated buffer of the correct size to accumulate gradients in. return_bias : bool, default = `False` when set to `True`, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the output of the linear transformation :math:`y = xA^T`. This is useful when the bias addition can be fused to subsequent operations. params_dtype : torch.dtype, default = `torch.float32` it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. """ def __init__( self, in_features: int, out_features: int, sequence_parallel: bool = False, fuse_wgrad_accumulation: bool = False, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, get_rng_state_tracker: Optional[Callable] = None, init_method: Optional[Callable] = None, bias: bool = True, return_bias: bool = False, params_dtype: torch.dtype = torch.float32, parallel_mode: Optional[str] = None, skip_weight_param_allocation: bool = False, parameters_split: Optional[Tuple[str, ...]] = None, ub_split_rs: bool = False, ub_split_ag: bool = False, ) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.fuse_wgrad_accumulation = fuse_wgrad_accumulation self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias self.parameters_split = parameters_split self.ub_split_rs = ub_split_rs self.ub_split_ag = ub_split_ag if ub_split_rs or ub_split_ag: assert ( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." 
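# Resolve the tensor-parallel setup: if no process group was passed, fall
# back to the tp_size argument (the group can still be supplied later via
# set_tensor_parallel_group); otherwise derive the world size from the group.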
if tp_group is None: self.tp_size = tp_size if tp_size == 1: self.set_tensor_parallel_group(tp_group) else: self.tp_size = get_distributed_world_size(tp_group) self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes ), f"parallel_mode {parallel_mode} not supported" if self.parallel_mode == "column": self.out_features = divide(self.out_features, self.tp_size) elif self.parallel_mode == "row": self.in_features = divide(self.in_features, self.tp_size) if init_method is None: init_method = get_default_init_method() self.sequence_parallel = (self.tp_size > 1) and sequence_parallel if not skip_weight_param_allocation: self.register_buffer("weight_tensor", torch.empty( self.out_features, self.in_features, device=torch.cuda.current_device(), dtype=params_dtype), persistent=False) initialize_affine_weight_gpu( self.weight_tensor, init_method, get_rng_state_tracker, partition_dim=1 if self.parallel_mode == "row" else 0, stride=1, ) if self.use_bias: self.register_buffer("bias_tensor", torch.empty( self.out_features, device=torch.cuda.current_device(), dtype=params_dtype), persistent=False) else: self.register_buffer( "bias_tensor", torch.Tensor().type(params_dtype), persistent=False ) with torch.no_grad(): self.bias_tensor.zero_() if parameters_split is None: parameters_split = ("",) assert ( self.out_features % len(parameters_split) == 0 ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" split_size = self.out_features // len(parameters_split) self.weight_names = [] self.bias_names = [] for i, pname in enumerate(parameters_split): wname = pname + "weight" bname = pname + "bias" self.register_parameter( wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size]) ) set_tensor_model_parallel_attributes( tensor=getattr(self, wname), is_parallel=True, dim=1 if parallel_mode == "row" else 0, stride=1, ) if self.use_bias: self.register_parameter( bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) ) else: self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False) if parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) self.weight_names.append(wname) self.bias_names.append(bname) self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.parallel_mode == "row" and self.apply_bias: self.gemm_bias_unfused_add = True else: self.gemm_bias_unfused_add = False def forward( self, inp: torch.Tensor, weight: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. Parameters ---------- inp : torch.Tensor Input tensor. weight : torch.Tensor, default = None An optional weight tensor for the module. This argument is compulsory if module is initialized with `skip_weight_param_allocation=True` bias : torch.Tensor, default = None An optional bias tensor for the module. This argument is compulsory if module is initialized with `skip_weight_param_allocation=True` and one of `use_bias` or `return_bias` is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split into microbatches. 
Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. When set, this parameter enables additional optimizations: * during FP8 training, it allows caching of the FP8 versions of the weights * it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced) """ with self.prepare_forward(inp, is_first_microbatch) as inp: bias_tensor = ( bias if bias is not None else self.bias if self.parameters_split is None else self.bias_tensor if not torch.is_grad_enabled() else self.noop_cat("bias_tensor", self.bias_names) ) weight_tensor = ( weight if weight is not None else self.weight if self.parameters_split is None else self.weight_tensor if not torch.is_grad_enabled() else self.noop_cat("weight_tensor", self.weight_names) ) if torch.is_grad_enabled(): linear_fn = _Linear.apply args = [] else: linear_fn = _Linear.forward args = [None] args += ( weight_tensor, self.weight1_fp8 if self.fp8 else None, self.weight1_t_fp8 if self.fp8 else None, inp, bias_tensor, self.apply_bias and not self.gemm_bias_unfused_add, is_first_microbatch, self.fp8, self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, self.tp_group, self.tp_size, self.sequence_parallel, self.tp_size > 1, self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), self.ub_split_rs, self.ub_split_ag, ) out = linear_fn(*args) if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) if self.return_bias: return out, cast_if_needed(bias_tensor, self.activation_dtype) return out class _LayerNormMLP(torch.autograd.Function): """LayerNormMLP semi-top level module Calls custom cuda extensions. """ @staticmethod def forward( ctx, inp: torch.Tensor, ln_weight: torch.Tensor, ln_bias: torch.Tensor, fc1_weight: torch.Tensor, fc1_weight_fp8: Union[torch.Tensor, None], fc1_weight_t_fp8: Union[torch.Tensor, None], fc1_bias: torch.Tensor, use_fc1_bias: bool, fc2_weight: torch.Tensor, fc2_weight_fp8: Union[torch.Tensor, None], fc2_weight_t_fp8: Union[torch.Tensor, None], fc2_bias: torch.Tensor, use_fc2_bias: bool, eps: float, is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, tp_group: Union[dist_group_type, None], tp_size: int, sequence_parallel: bool, tensor_parallel: bool, activation_dtype: torch.dtype, return_layernorm_output: bool, bias_gelu_nvfusion: bool, set_parallel_mode: bool, is_grad_enabled: bool, fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_split_rs: bool, ub_split_ag: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() assert inp.shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) assert ( not fp8 or check_dim_for_fp8_forward_exec(inputmat, fc1_weight, fc2_weight) ), "Input and weight dimensions are not compatible for FP8 execution." 
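# FP8 weight buffers are (re)cast only when the weights may have changed:
# always when is_first_microbatch is None, and on the first microbatch
# otherwise; later microbatches reuse the cached FP8 weight copies.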
update_fp8_weights = is_first_microbatch is None or is_first_microbatch # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) ln_weight = cast_if_needed(ln_weight, activation_dtype) ln_bias = cast_if_needed(ln_bias, activation_dtype) if ub_split_ag: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: ub_split_ag = False if ub_split_ag: ub_obj_lnout = get_ub("fc1_fprop") ln_out = ub_obj_lnout.get_ubuf_output(0) if ub_split_rs: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1: ub_split_rs = False # If residual connection is after LN, we need `ln_out` # tensor in higher precision, this comes at the cost # of an extra fp8 cast. if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if not return_layernorm_output: if is_grad_enabled: if not ub_split_ag: ln_out = torch.empty_like(inputmat, dtype=torch.uint8) _, mu, rsigma = layernorm_fwd_fp8( inputmat, ln_weight, ln_bias, eps, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, fwd_ln_sm_margin, zero_centered_gamma, ln_out = ln_out, ) else: ln_out = layernorm_fwd_fp8_inf( inputmat, ln_weight, ln_bias, eps, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, zero_centered_gamma, ) else: ln_out_return, mu, rsigma = tex.layernorm_fwd( inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma ) ln_out = cast_to_fp8( ln_out_return, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, ) else: if is_grad_enabled: if ub_split_ag: _, mu, rsigma = tex.layernorm_fwd_noalloc( inputmat, ln_weight, ln_bias, ln_out, eps, fwd_ln_sm_margin, zero_centered_gamma ) else: ln_out, mu, rsigma = tex.layernorm_fwd( inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma ) else: ln_out, mu, rsigma = layernorm_fwd_inf( inputmat, ln_weight, ln_bias, eps, zero_centered_gamma ), None, None ln_out_return = ln_out # Column Parallel Linear if ub_split_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) elif set_parallel_mode and sequence_parallel: ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) else: ln_out_total = ln_out if fp8: bias_dtype = ( torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype ) fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias if update_fp8_weights: if is_grad_enabled: fp8_cast_transpose_fused( fc1_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, cast_out=fc1_weight_fp8, transpose_out=fc1_weight_t_fp8, ) fp8_cast_transpose_fused( fc2_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, cast_out=fc2_weight_fp8, transpose_out=fc2_weight_t_fp8, ) else: fc1_weight_t_fp8 = None fc1_weight_fp8 = cast_to_fp8( fc1_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ) fc2_weight_t_fp8 = None fc2_weight_fp8 = cast_to_fp8( fc2_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, ) fc1_out = fp8_gemm( fc1_weight_fp8, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ln_out_total, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, activation_dtype, get_workspace(), bias=fc1_bias, use_bias=use_fc1_bias, use_split_accumulator=_2X_ACC_FPROP, 
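# With ub_split_ag, the GEMM below reads its gathered input from the
# userbuffer (ub_obj_lnout) and writes the local, non-gathered LN output
# into `ln_out` via extra_output_tensor, overlapping the all-gather with FC1.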
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, ub=ub_obj_lnout if ub_split_ag else None, extra_output_tensor=ln_out if ub_split_ag else None, ) gelu_out = fp8_gelu( fc1_out, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM2_INPUT, fp8_dtype_forward, ) if ub_split_rs: ub_obj_fc2out = get_ub("fc2_fprop") fc2_out = ub_obj_fc2out.get_ubuf_output(1) dim_size = list(gelu_out.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) else: dim_size = list(gelu_out.size()) dim_size[1] = fc2_weight.size(0) fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) _ = fp8_gemm( fc2_weight_fp8, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, gelu_out, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM2_INPUT, fp8_dtype_forward, activation_dtype, get_workspace(), bias=fc2_bias, use_bias=use_fc2_bias, use_split_accumulator=_2X_ACC_FPROP, out=fc2_out, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, ub=ub_obj_fc2out if ub_split_rs else None, extra_output_tensor=rs_out if ub_split_rs else None, ) else: # Cast for native AMP fc1_weight = cast_if_needed(fc1_weight, activation_dtype) fc2_weight = cast_if_needed(fc2_weight, activation_dtype) fc1_bias = ( cast_if_needed(fc1_bias, activation_dtype) if use_fc1_bias else fc1_bias ) fc2_bias = ( cast_if_needed(fc2_bias, activation_dtype) if use_fc2_bias else fc2_bias ) if fp8_calibration: # amax of fc1 input fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \ torch.amax(ln_out_total).float() # amax of fc1 weight fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ torch.amax(fc1_weight).float() fc1_outputs = gemm( fc1_weight, ln_out_total, activation_dtype, get_workspace(), bias=fc1_bias, use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, gelu=not bias_gelu_nvfusion, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, ub=ub_obj_lnout if ub_split_ag else None, extra_output_tensor=ln_out if ub_split_ag else None, ) if bias_gelu_nvfusion: fc1_out, _, _ = fc1_outputs gelu_out = bias_gelu_fused(fc1_out, fc1_bias) else: gelu_out, _, fc1_out = fc1_outputs if fp8_calibration: # amax of fc2 input fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = \ torch.amax(gelu_out).float() # amax of fc2 weight fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \ torch.amax(fc2_weight).float() if ub_split_rs: ub_obj_fc2out = get_ub("fc2_fprop") fc2_out = ub_obj_fc2out.get_ubuf_output(1) dim_size = list(gelu_out.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) else: dim_size = list(gelu_out.size()) dim_size[1] = fc2_weight.size(0) fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) _, _, _ = gemm( fc2_weight, gelu_out, activation_dtype, get_workspace(), bias=fc2_bias, use_bias=use_fc2_bias, out=fc2_out, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, ub=ub_obj_fc2out if ub_split_rs else None, extra_output_tensor=rs_out if ub_split_rs else None, ) if is_grad_enabled: ctx.save_for_backward( inputmat, ln_weight, mu, rsigma, ln_out, fc1_out, gelu_out, fc1_weight, fc1_weight_t_fp8, fc2_weight, fc2_weight_t_fp8, fc1_bias, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ) 
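# Stash everything backward needs on the autograd context: the LN input and
# statistics, the (possibly FP8) GEMM inputs/outputs, both weights and their
# FP8 transposes, the FC1 bias for the fused bias+GeLU backward, and a clone
# of the forward scale inverses so later in-place scale updates do not
# affect this iteration's backward pass.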
ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.is_first_microbatch = is_first_microbatch ctx.use_fc1_bias = use_fc1_bias ctx.use_fc2_bias = use_fc2_bias ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp.shape ctx.tp_group = tp_group ctx.tp_size = tp_size ctx.bias_gelu_nvfusion = bias_gelu_nvfusion ctx.return_layernorm_output = return_layernorm_output ctx.set_parallel_mode = set_parallel_mode ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_split_ag = ub_split_ag ctx.requires_dgrad = inp.requires_grad # Row Parallel Linear if ub_split_rs: fc2_out = rs_out elif set_parallel_mode and sequence_parallel: fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group) elif set_parallel_mode and tensor_parallel: fc2_out, _ = allreduce(fc2_out, tp_group) # [*, in_features] -> [*, out_features] except first dimension changes for SP fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1]) if return_layernorm_output: return fc2_out, ln_out_return.view_as(inp) return fc2_out @staticmethod def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: with _prepare_backward( ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP" ): ( inputmat, ln_weight, mu, rsigma, ln_out, fc1_out, gelu_out, fc1_weight, fc1_weight_t_fp8, fc2_weight, fc2_weight_t_fp8, fc1_bias, fwd_scale_inverses, ) = ctx.saved_tensors if ctx.ub_bulk_dgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1: ctx.ub_bulk_dgrad = False if ctx.ub_bulk_dgrad: dim_size = list(ln_out.size()) dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub("fc1_dgrad") ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) if ctx.ub_split_ag: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1: ctx.ub_split_ag = False if ctx.ub_split_ag: dim_size = list(grad_outputs[0].size()) dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub("fc2_dgrad") ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess ( grad_output, grad_output_c, grad_output_t, fc2_bias_grad, ) = TransformerEngineBaseModule.grad_output_preprocess( ctx, grad_outputs[0], True ) if ctx.ub_bulk_wgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1: ctx.ub_bulk_wgrad = False # Column Parallel Linear # Overlap input AG with dgrad if (not ctx.ub_bulk_dgrad) and ctx.set_parallel_mode and ctx.sequence_parallel: ln_out_total, handle = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True ) else: ln_out_total = ln_out if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation if ctx.fp8: fp8_dtype_forward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=True ) fp8_dtype_backward = get_fp8_te_dtype( ctx.fp8_meta["recipe"], fprop_tensor=False ) # FC2 DGRAD; Unconditional fc2_dgrad = fp8_gemm( fc2_weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, grad_output_c, ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if 
ctx.ub_split_ag else None, ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, ) if ctx.ub_split_ag: grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) # FC2 WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: if fc2_weight.requires_grad: gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward) fc2_wgrad = fp8_gemm( gelu_out_t, fwd_scale_inverses, tex.FP8FwdTensors.GEMM2_INPUT, fp8_dtype_forward, grad_output_t, ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT1, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), accumulate=accumulate_wgrad_into_param_main_grad, out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, ) fc1_bias_grad, dgelu, dgelu_t = fp8_cast_transpose_bgrad_dgelu_fused( fc2_dgrad, fc1_out, ctx.fp8_meta["scaling_bwd"], tex.FP8BwdTensors.GRAD_OUTPUT2, fp8_dtype_backward, ) else: if fc2_weight.requires_grad: gelu_out_c = cast_from_fp8( gelu_out, ctx.fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM2_INPUT, fp8_dtype_forward, TE_DType[ctx.activation_dtype], ) fc2_wgrad, _, _ = gemm( gelu_out_c, grad_output, ctx.activation_dtype, get_workspace(), layout="NT", grad=True, use_bias=False, accumulate=accumulate_wgrad_into_param_main_grad, out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused( fc2_dgrad, fc1_out, fc1_bias ) dgelu = cast_to_fp8( dgelu_no_fp8, ctx.fp8_meta["scaling_bwd"], tex.FP8BwdTensors.GRAD_OUTPUT2, fp8_dtype_backward, ) dgelu_t = None fc1_dgrad_size = list(dgelu.size()) fc1_dgrad_size[1] = fc1_weight.size(1) if ctx.ub_bulk_wgrad: # allocate dgrad output ub_obj_dgrad = get_ub("fc1_wgrad") fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output else: fc1_dgrad = torch.empty( fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device ) # FC1 DGRAD: Unconditional _ = fp8_gemm( fc1_weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, dgelu, ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT2, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), out=fc1_dgrad, use_split_accumulator=_2X_ACC_DGRAD, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None ) else: # FC2 DGRAD; Unconditional fc2_dgrad, _, _ = gemm( fc2_weight, grad_output, ctx.activation_dtype, get_workspace(), layout="NN", gelu=not ctx.bias_gelu_nvfusion, grad=True, gelu_input=fc1_out, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, ) # FC2 WGRAD if fc2_weight.requires_grad: fc2_wgrad, fc2_bias_grad, _ = gemm( gelu_out, grad_output, ctx.activation_dtype, get_workspace(), layout="NT", grad=True, use_bias=ctx.use_fc2_bias, accumulate=accumulate_wgrad_into_param_main_grad, out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) if ctx.bias_gelu_nvfusion: fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias) else: dgelu = fc2_dgrad fc1_dgrad_size = list(dgelu.size()) fc1_dgrad_size[1] = fc1_weight.size(1) if ctx.ub_bulk_wgrad: # allocate dgrad output ub_obj_dgrad = get_ub("fc1_wgrad") fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output else: fc1_dgrad = torch.empty( fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device ) # FC1 DGRAD: Unconditional _, _, _ = gemm( fc1_weight, dgelu, ctx.activation_dtype, get_workspace(), out=fc1_dgrad, layout="NN", grad=True, 
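# With ub_bulk_dgrad, this dgrad GEMM is bulk-overlapped with the all-gather
# of ln_out that was staged into the userbuffer earlier; the gathered result
# is read back from the buffer after the GEMM finishes.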
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None, ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None ) if ctx.ub_bulk_dgrad: ln_out_total = ub_obj_lnout.get_ubuf_output(1) # Overlap dgrad-RS/AR with wgrad if ctx.set_parallel_mode and ctx.sequence_parallel: if not ctx.ub_bulk_dgrad: handle.wait() if not ctx.ub_bulk_wgrad: fc1_dgrad, handle = reduce_scatter_along_first_dim( fc1_dgrad, ctx.tp_group, async_op=True ) elif ctx.set_parallel_mode and ctx.tensor_parallel: fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) if fc1_weight.requires_grad: if ctx.fp8: # FC1 WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) fc1_wgrad = fp8_gemm( ln_out_total_t, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, dgelu_t, ctx.fp8_meta["scaling_bwd"].scale_inv, tex.FP8BwdTensors.GRAD_OUTPUT2, fp8_dtype_backward, ctx.activation_dtype, get_workspace(), accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) else: ln_out_total_c = cast_from_fp8( ln_out_total, ctx.fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, TE_DType[ctx.activation_dtype], ) fc1_wgrad, _, _ = gemm( ln_out_total_c, dgelu_no_fp8, ctx.activation_dtype, get_workspace(), layout="NT", grad=True, accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) else: # FC1 WGRAD fc1_wgrad_outputs = gemm( ln_out_total, dgelu, ctx.activation_dtype, get_workspace(), layout="NT", grad=True, use_bias=not ctx.bias_gelu_nvfusion, accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ) if ctx.bias_gelu_nvfusion: fc1_wgrad, _, _ = fc1_wgrad_outputs else: fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs # Column Parallel Linear if ctx.ub_bulk_wgrad: fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output elif ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None: handle.wait() # LayerNorm gradient d_ln_out = fc1_dgrad.view(inputmat.shape) # Residual gradient if ctx.return_layernorm_output: d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) dxmat, dgamma, dbeta = tex.layernorm_bwd( d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ) return ( dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, dbeta, fc1_wgrad if fc1_weight.requires_grad else None, None, None, fc1_bias_grad if ctx.use_fc1_bias else None, None, fc2_wgrad if fc2_weight.requires_grad else None, None, None, fc2_bias_grad if ctx.use_fc2_bias else None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, ) class LayerNormMLP(TransformerEngineBaseModule): r""" Applies layer normalization on the input followed by the MLP module, consisting of 2 successive linear transformations, separated by the GeLU activation. Parameters ---------- hidden_size : int size of each input sample. 
ffn_hidden_size : int intermediate size to which input samples are projected. eps : float, default = 1e-5 a value added to the denominator of layer normalization for numerical stability. bias : bool, default = `True` if set to `False`, the FC1 and FC2 layers will not learn an additive bias. init_method : Callable, default = `None` used for initializing FC1 weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. output_layer_init_method : Callable, default = `None` used for initializing FC2 weights in the following way: `output_layer_init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. return_layernorm_output : bool, default = `False` if set to `True`, output of layernorm is returned from the forward together with the output of the linear transformation. Example use case: residual connection for transformer module is taken post layernorm. zero_centered_gamma : bool, default = `False` if set to `True`, gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to .. math:: y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta Parallelism parameters ---------------------- set_parallel_mode : bool, default = `False` if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row Parallel as described `here `_. sequence_parallel : bool, default = `False` if set to `True`, uses sequence parallelism. tp_group : ProcessGroup, default = `None` tensor parallel process group. tp_size : int, default = 1 used as TP (tensor parallel) world size when TP groups are not formed during initialization. In this case, users must call the `set_tensor_parallel_group(tp_group)` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. Optimization parameters ----------------------- fuse_wgrad_accumulation : bool, default = `False` if set to `True`, enables fusing of creation and accumulation of the weight gradient. return_bias : bool, default = `False` when set to `True`, this module will not apply the additive bias for FC2, but instead return the bias value during the forward pass together with the output of the linear transformation :math:`y = xA^T`. This is useful when the bias addition can be fused to subsequent operations. params_dtype : torch.dtype, default = `torch.float32` it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. seq_length: int sequence length of input samples. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for forward propagation and activation recompute phase. micro_batch_size: int batch size per training step. Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for forward propagation and activation recompute phase.
""" def __init__( self, hidden_size: int, ffn_hidden_size: int, eps: float = 1e-5, sequence_parallel: bool = False, return_bias: bool = False, get_rng_state_tracker: Optional[Callable] = None, tp_group: Optional[dist_group_type] = None, tp_size: int = 1, init_method: Optional[Callable] = None, bias: bool = True, output_layer_init_method: Optional[Callable] = None, fuse_wgrad_accumulation: bool = False, params_dtype: torch.dtype = torch.float32, return_layernorm_output: bool = False, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, set_parallel_mode: bool = False, zero_centered_gamma: bool = False, ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_split_rs: bool = False, ub_split_ag: bool = False, ) -> None: super().__init__() self.fuse_wgrad_accumulation = fuse_wgrad_accumulation self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias self.return_layernorm_output = return_layernorm_output self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1"))) self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_rs = ub_split_rs self.ub_split_ag = ub_split_ag if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_rs or ub_split_ag: assert ( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." if tp_group is None: self.tp_size = tp_size if tp_size == 1: self.set_tensor_parallel_group(tp_group) else: self.tp_size = get_distributed_world_size(tp_group) self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() if init_method is None: init_method = get_default_init_method() if output_layer_init_method is None: output_layer_init_method = get_default_init_method() self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.size_per_partition = divide(ffn_hidden_size, self.tp_size) # LN init self.eps = eps self.layer_norm_weight = Parameter( torch.empty( hidden_size, device=torch.cuda.current_device(), dtype=params_dtype, ) ) self.layer_norm_bias = Parameter( torch.empty( hidden_size, device=torch.cuda.current_device(), dtype=params_dtype, ) ) setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) self.reset_layer_norm_parameters() # FC1 init self.fc1_weight = Parameter( torch.empty( self.size_per_partition, hidden_size, device=torch.cuda.current_device(), dtype=params_dtype, ) ) self.fp8_weight_shapes.append(self.fc1_weight.shape) initialize_affine_weight_gpu( self.fc1_weight, init_method, get_rng_state_tracker, partition_dim=0, stride=1, ) if self.use_bias: self.fc1_bias = Parameter( torch.empty( self.size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype, ) ) set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) else: self.register_buffer("fc1_bias", torch.Tensor().type(params_dtype), persistent=False) with torch.no_grad(): self.fc1_bias.zero_() # FC2 init self.fc2_weight = Parameter( torch.empty( hidden_size, self.size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype, ) ) self.fp8_weight_shapes.append(self.fc2_weight.shape) initialize_affine_weight_gpu( self.fc2_weight, output_layer_init_method, get_rng_state_tracker, partition_dim=1, stride=1, ) if self.use_bias: self.fc2_bias = Parameter( torch.empty( hidden_size, device=torch.cuda.current_device(), dtype=params_dtype ) ) else: 
self.register_buffer("fc2_bias", torch.Tensor().type(params_dtype), persistent=False) # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.set_parallel_mode and self.apply_bias: self.gemm_bias_unfused_add = True else: self.gemm_bias_unfused_add = False with torch.no_grad(): self.fc2_bias.zero_() if self.bias_gelu_nvfusion: set_jit_fusion_options() if seq_length and micro_batch_size: warmup_jit_bias_gelu_all_dtypes( self.size_per_partition, seq_length, micro_batch_size ) # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. These envvars can be used to prevent the LN # kernels from using all SMs in the device. This is useful for cases such as # communication overlap with LN. self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) def reset_layer_norm_parameters(self) -> None: """Init LN params""" if not self.zero_centered_gamma: init.ones_(self.layer_norm_weight) else: init.zeros_(self.layer_norm_weight) init.zeros_(self.layer_norm_bias) def forward( self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a feedforward network (MLP Block). Parameters ---------- inp : torch.Tensor Input tensor. is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split into microbatches. Between the microbatches of the same minibatch the model weights are not updated. Setting this parameter indicates whether the current microbatch is the first in a minibatch or not. 
When set, this parameter enables additional optimizations: * during FP8 training, it allows caching of the FP8 versions of the weights * it also allows skipping gradient accumulation during the first microbatch (since it is the first gradient being produced) """ with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: if torch.is_grad_enabled(): fwd_fn = _LayerNormMLP.apply args = [] else: fwd_fn = _LayerNormMLP.forward args = [None] args += ( inp, self.layer_norm_weight, self.layer_norm_bias, self.fc1_weight, self.weight1_fp8 if self.fp8 else None, self.weight1_t_fp8 if self.fp8 else None, self.fc1_bias, self.use_bias, self.fc2_weight, self.weight2_fp8 if self.fp8 else None, self.weight2_t_fp8 if self.fp8 else None, self.fc2_bias, self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, self.fp8, self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, self.tp_group, self.tp_size, self.sequence_parallel, self.tp_size > 1, self.activation_dtype, self.return_layernorm_output, self.bias_gelu_nvfusion, self.set_parallel_mode, torch.is_grad_enabled(), self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_split_rs, self.ub_split_ag, ) out = fwd_fn(*args) if self.return_layernorm_output: out, ln_out = out if self.gemm_bias_unfused_add: out = out + cast_if_needed(self.fc2_bias, self.activation_dtype) if self.return_bias: if self.return_layernorm_output: return out, cast_if_needed(self.fc2_bias, self.activation_dtype), ln_out return out, cast_if_needed(self.fc2_bias, self.activation_dtype) if self.return_layernorm_output: return out, ln_out return out class _LayerNorm(torch.autograd.Function): """functional LayerNorm""" @staticmethod def forward( ctx, inp: torch.Tensor, ln_weight: torch.Tensor, ln_bias: torch.Tensor, eps: float, fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, ) -> torch.Tensor: # Make sure input dimensions are compatible in_features = ln_weight.numel() assert inp.is_cuda, "TransformerEngine needs CUDA." assert inp.shape[-1] == in_features, "LayerNorm not possible" inputmat = inp.view((-1, in_features)) ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma) ctx.save_for_backward(inputmat, ln_weight, mu, rsigma) ctx.inp_shape = inp.shape ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma return ln_out.view_as(inp) @staticmethod def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: inputmat, ln_weight, mu, rsigma = ctx.saved_tensors grad_output = grad_output.contiguous() d_ln_out = grad_output.view(inputmat.shape) dxmat, dgamma, dbeta = tex.layernorm_bwd( d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ) return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None class LayerNorm(torch.nn.Module): r""" Applies Layer Normalization over a mini-batch of inputs as described in the paper `Layer Normalization `__ .. math:: y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of size :attr:`hidden_size` Parameters ---------- hidden_size : int size of each input sample. eps : float, default = 1e-5 a value added to the denominator of layer normalization for numerical stability. sequence_parallel : bool, default = `False` if set to `True`, uses sequence parallelism. 
params_dtype : torch.dtype, default = `torch.float32` it controls the type used to allocate the initial parameters. Useful when the model is trained with lower precision and the original FP32 parameters would not fit in GPU memory. zero_centered_gamma : bool, default = 'False' if set to 'True', gamma parameter in LayerNorm is initialized to 0 and the LayerNorm formula changes to .. math:: y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta """ def __init__( self, hidden_size: int, eps: float = 1e-5, sequence_parallel: bool = False, params_dtype: torch.dtype = torch.float32, zero_centered_gamma: bool = False, ) -> None: super().__init__() self.eps = eps self.zero_centered_gamma = zero_centered_gamma self.weight = Parameter( torch.empty( hidden_size, device=torch.cuda.current_device(), dtype=params_dtype, ) ) self.bias = Parameter( torch.empty( hidden_size, device=torch.cuda.current_device(), dtype=params_dtype, ) ) setattr(self.weight, "sequence_parallel", sequence_parallel) setattr(self.bias, "sequence_parallel", sequence_parallel) self.reset_layer_norm_parameters() # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. These envvars can be used to prevent the LN # kernels from using all SMs in the device. This is useful for cases such as # communication overlap with LN. self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) def load_state_dict( self, state_dict: Mapping[str, Any], strict: bool = True, ) -> None: """Override PyTorch loader to maintain backward compatibility with previous version of LayerNorm parameter names. """ if "layer_norm_weight" in state_dict: state_dict["weight"] = state_dict["layer_norm_weight"] del state_dict["layer_norm_weight"] if "layer_norm_bias" in state_dict: state_dict["bias"] = state_dict["layer_norm_bias"] del state_dict["layer_norm_bias"] super().load_state_dict(state_dict, strict) def reset_layer_norm_parameters(self) -> None: """Init LN params""" if not self.zero_centered_gamma: init.ones_(self.weight) else: init.zeros_(self.weight) init.zeros_(self.bias) def forward(self, inp: torch.Tensor) -> torch.Tensor: """LayerNorm FWD""" # Maintain backward compatibility. if hasattr(self, "layer_norm_weight"): setattr(self, "weight", self.layer_norm_weight) if hasattr(self, "layer_norm_bias"): setattr(self, "bias", self.layer_norm_bias) return _LayerNorm.apply( inp, self.weight, self.bias, self.eps, self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma )
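# Minimal usage sketch of the modules defined above (a single CUDA device is
# assumed and the sizes are illustrative, not prescribed by the library; FP8
# execution would additionally require wrapping the forward passes in the
# fp8_autocast context from transformer_engine.pytorch).
if __name__ == "__main__" and torch.cuda.is_available():
    hidden_size, ffn_hidden_size, seq_len, batch_size = 1024, 4096, 128, 2

    ln = LayerNorm(hidden_size)
    proj = Linear(hidden_size, hidden_size, bias=True)
    mlp = LayerNormMLP(hidden_size, ffn_hidden_size)

    inp = torch.randn(seq_len, batch_size, hidden_size, device="cuda")
    out = mlp(proj(ln(inp)))   # [seq_len, batch_size, hidden_size]
    out.sum().backward()       # exercises the custom autograd backward paths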