# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Methods needed for distributed training (DP/TP)."""

from contextlib import contextmanager
from typing import Any, Dict, Union, Optional, Callable, Tuple

import torch
from torch.cuda import _lazy_call
from torch.utils.checkpoint import detach_variable

from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
from .fp8 import is_fp8_enabled


_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    "tensor_model_parallel": False,
    "partition_dim": -1,
    "partition_stride": 1,
}

_FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False


def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -> None:
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state

    This function is adapted from the PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for 4+ GPU cases.
    """
    if device == -1:
        device = torch.device("cuda")
    elif isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def cb() -> None:
        idx = device.index
        if idx is None:
            idx = torch.cuda.current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.set_state(new_state)

    _lazy_call(cb)


def set_tensor_model_parallel_attributes(
    tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int
) -> None:
    """Set attributes needed for TP."""
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attribute)

    # Set the attributes.
    setattr(tensor, "tensor_model_parallel", is_parallel)
    setattr(tensor, "partition_dim", dim)
    setattr(tensor, "partition_stride", stride)


def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
    """Return world size for the distributed group."""
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size(group=group)


def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
    """Return my rank for the distributed group."""
    assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
    return torch.distributed.get_rank(group=group)
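
# Illustrative sketch (not part of the library API): how the attribute helper
# above is typically used to tag a column-parallel weight with TP metadata.
# The weight shape and the column-parallel layout (dim=0) are assumptions made
# for demonstration only.
def _example_set_tp_attributes() -> None:
    weight = torch.empty(1024, 256)
    set_tensor_model_parallel_attributes(
        tensor=weight, is_parallel=True, dim=0, stride=1
    )
    assert weight.tensor_model_parallel and weight.partition_dim == 0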
def initialize_affine_weight_gpu(
    weight: torch.Tensor,
    init_method: Callable,
    get_rng_state_tracker: Callable,
    partition_dim: int,
    stride: int = 1,
) -> None:
    """Initialize affine weight for model parallel on GPU."""

    set_tensor_model_parallel_attributes(
        tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
    )

    if get_rng_state_tracker is None:
        init_method(weight)
        return

    with get_rng_state_tracker().fork():
        init_method(weight)


def split_tensor_into_1d_equal_chunks(
    tensor: torch.Tensor, tp_group: dist_group_type, new_buffer: bool = False
) -> torch.Tensor:
    """Break a tensor into equal 1D chunks."""
    partition_size = torch.numel(tensor) // get_distributed_world_size(tp_group)
    start_index = partition_size * get_distributed_rank(tp_group)
    end_index = start_index + partition_size
    if new_buffer:
        data = torch.empty(
            partition_size,
            dtype=tensor.dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )
        data.copy_(tensor.view(-1)[start_index:end_index])
    else:
        data = tensor.view(-1)[start_index:end_index]
    return data


def gather_split_1d_tensor(
    tensor: torch.Tensor, tp_group: dist_group_type
) -> torch.Tensor:
    """Inverse of `split_tensor_into_1d_equal_chunks`: gather the 1D chunks
    from all tensor parallel ranks back into a flat tensor."""
    numel_gathered = torch.numel(tensor) * get_distributed_world_size(tp_group)
    gathered = torch.empty(
        numel_gathered,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    torch.distributed.all_gather_into_tensor(gathered, tensor, group=tp_group)
    return gathered


@contextmanager
def activation_recompute_forward(
    activation_recompute: bool = False,
    recompute_phase: bool = False,
) -> None:
    """Context manager used to control the forward runtime behavior when
    executed under the `CheckpointFunction` function. For running FP8, the
    forward pass will run without storing intermediate activations. Instead,
    the forward pass saves the inputs tuple and the calling function. In the
    backward pass, these are retrieved, and the forward pass is computed
    again while tracking the intermediate activations, followed by calculation
    of gradients using these values.
    """
    global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
    try:
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = (
            activation_recompute and is_fp8_enabled()
        )
        _FP8_ACTIVATION_RECOMPUTE_PHASE = recompute_phase
        yield
    finally:
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
        _FP8_ACTIVATION_RECOMPUTE_PHASE = False


def is_fp8_activation_recompute_enabled() -> bool:
    """Return global boolean."""
    return _FP8_ACTIVATION_RECOMPUTE_ENABLED


def in_fp8_activation_recompute_phase() -> bool:
    """Return global boolean."""
    return _FP8_ACTIVATION_RECOMPUTE_PHASE
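
# Illustrative sketch (not part of the library API): the split/gather helpers
# above are inverses of each other. Splitting a tensor into per-rank 1D chunks
# and then all-gathering those chunks reconstructs the flattened original,
# provided every rank holds identical data and the element count is divisible
# by the group size. Assumes torch.distributed is initialized and `tp_group`
# is a valid tensor parallel process group.
def _example_split_gather_roundtrip(tp_group: dist_group_type) -> None:
    # Deterministic data so all ranks hold the same values.
    full = torch.arange(32, dtype=torch.float32, device="cuda").view(4, 8)
    chunk = split_tensor_into_1d_equal_chunks(full, tp_group, new_buffer=True)
    rebuilt = gather_split_1d_tensor(chunk, tp_group).view(full.shape)
    assert torch.equal(rebuilt, full)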
class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
    two main changes:
        1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
        2) the states in the model parallel tracker are also properly
           tracked/set/reset.
    """

    @staticmethod
    def forward(
        ctx,
        run_function: Callable,
        distribute_saved_activations: bool,
        get_cuda_rng_tracker: Callable,
        tp_group: dist_group_type,
        kwargs: Dict[str, Any],
        *args: Tuple[torch.Tensor, ...],
    ) -> Tuple[torch.Tensor, ...]:
        """Call forward function while saving state to be able to
        redo the computation later."""
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():
            with activation_recompute_forward(
                activation_recompute=True, recompute_phase=False
            ):
                outputs = run_function(*args, **kwargs)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(
                    args[0].data, tp_group, new_buffer=True
                ),
            )

        # Store everything.
        ctx.save_for_backward(*args)
        ctx.get_cuda_rng_tracker = get_cuda_rng_tracker
        ctx.tp_group = tp_group
        ctx.kwargs = kwargs

        return outputs

    @staticmethod
    def backward(
        ctx, *args: Tuple[Union[torch.Tensor, None], ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Call backward function with activation recomputation."""
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "
                "please use .backward() if possible"
            )

        inputs = ctx.saved_tensors
        get_cuda_rng_tracker = ctx.get_cuda_rng_tracker

        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view(
                    ctx.input_0_shape
                ),
            )

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what they were before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # Compute the forward pass.
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad():
            with activation_recompute_forward(
                activation_recompute=True, recompute_phase=True
            ):
                outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)

        # Set the states back to what they were at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)

        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else inp
            for inp in detached_inputs
        )
        return (None, None, None, None, None) + grads
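
# Illustrative sketch (not part of the library API): why `CheckpointFunction`
# snapshots and restores RNG state. Replaying a stochastic op such as dropout
# reproduces the forward result only if both generators are rewound first,
# exactly as the backward pass above does before recomputation.
def _example_rng_replay() -> None:
    x = torch.randn(4, 4, device="cuda")
    cpu_state = torch.get_rng_state()
    cuda_state = torch.cuda.get_rng_state()
    y1 = torch.nn.functional.dropout(x, p=0.5)
    # Rewind both generators so the replay sees the same random mask.
    torch.set_rng_state(cpu_state)
    _set_cuda_rng_state(cuda_state)
    y2 = torch.nn.functional.dropout(x, p=0.5)
    assert torch.equal(y1, y2)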
def checkpoint(
    function: Callable,
    distribute_saved_activations: bool,
    get_cuda_rng_tracker: Callable,
    tp_group: dist_group_type,
    *args: Tuple[torch.Tensor, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """
    Checkpoint a part of the model by trading compute for memory. This function is based on
    `torch.utils.checkpoint.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_.

    .. warning::

        It is the user's responsibility to ensure identical behavior when calling
        :attr:`function` from the forward and backward pass. If different output is
        produced (e.g. due to global state), then the checkpointed version won't
        be numerically equivalent.

    .. warning::

        The tuple :attr:`args` must contain only tensors (or :attr:`None`) in order
        to comply with PyTorch's :attr:`save_for_backward` method. :attr:`function`
        must be callable to produce valid outputs with the inputs :attr:`args`
        and :attr:`kwargs`.

    Parameters
    ----------
    function: Callable
             function or module used to run the forward and backward passes with
             the given :attr:`args` and :attr:`kwargs`.
    distribute_saved_activations: bool
             if set to `True`, the first tensor argument is distributed across the
             specified tensor parallel group (`tp_group`) before saving it for the
             backward pass.
    get_cuda_rng_tracker: Callable
             python function with the functionality to retrieve a state via
             :attr:`state = get_cuda_rng_tracker().get_states()` and to reset the
             state via :attr:`get_cuda_rng_tracker().set_states(state)`. This is
             used to ensure any extra cuda rng state or general global state can be
             reproduced across the 2 forward phases; original and recompute.
    tp_group : ProcessGroup, default = `None`
             tensor parallel process group.
    args : tuple
             tuple of torch tensors for inputs to :attr:`function`.
    kwargs : dict
             dictionary of string keys for keyword arguments to :attr:`function`.
    """
    return CheckpointFunction.apply(
        function,
        distribute_saved_activations,
        get_cuda_rng_tracker,
        tp_group,
        kwargs,
        *args,
    )


def reduce_scatter_along_first_dim(
    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reduce-scatter the input tensor across model parallel group."""
    world_size = get_distributed_world_size(tp_group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_, None

    dim_size = list(input_.size())
    assert (
        dim_size[0] % world_size == 0
    ), "First dimension of the tensor should be divisible by tensor parallel size"

    dim_size[0] = dim_size[0] // world_size

    output = torch.empty(
        dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=tp_group, async_op=async_op
    )
    return output, handle


def gather_along_first_dim(
    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Gather tensors and concatenate along the first dimension."""
    world_size = get_distributed_world_size(tp_group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_, None

    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(
        dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=tp_group, async_op=async_op
    )
    return output, handle


def gather_along_last_dim(
    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Gather tensors and concatenate along the last dimension."""
    world_size = get_distributed_world_size(tp_group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_, None

    dim_size = list(input_.size())
    dim_size[-1] = dim_size[-1] * world_size

    output = torch.empty(
        dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=tp_group, async_op=async_op
    )
    return output, handle


def allreduce(
    input_: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
    async_op: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """All-reduce the input tensor across model parallel group."""
    # Bypass the function if we are using only 1 GPU.
    if get_distributed_world_size(tp_group) == 1:
        return input_, None

    # All-reduce.
    handle = torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op)

    return input_, handle
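
# Illustrative sketch (not part of the library API): round-tripping the
# collectives above. Reduce-scattering along the first dimension and then
# all-gathering the result is equivalent to a plain all-reduce. Assumes
# torch.distributed is initialized (e.g. with a NCCL backend) and `tp_group`
# is a valid process group whose size divides the first dimension.
def _example_collectives_roundtrip(tp_group: dist_group_type) -> None:
    world_size = get_distributed_world_size(tp_group)
    # Identical data on every rank keeps the comparison simple.
    x = torch.ones(4 * world_size, 8, device="cuda")
    scattered, _ = reduce_scatter_along_first_dim(x, tp_group)
    gathered, _ = gather_along_first_dim(scattered, tp_group)
    reduced, _ = allreduce(x.clone(), tp_group)
    assert torch.equal(gathered, reduced)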