Commit d444a97a authored by yangzhong

Initial upload

## How to use PyTorch FSDP2?
Add these flags to enable Torch FSDP2:
```
--use-torch-fsdp2
--no-gradient-accumulation-fusion
--ckpt-format torch_dist
```
It is worth noting that CUDA_DEVICE_MAX_CONNECTIONS=1 should not be set, so that FSDP communication and computation on the primary stream can fully overlap.
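As a quick sanity check before launching, here is a minimal Python sketch (not part of this repo; the environment-variable name follows the comments elsewhere in these files) that warns when the connection limit is pinned to 1:
```
# Illustrative check only (not from the source): warn if CUDA_DEVICE_MAX_CONNECTIONS
# is pinned to 1, which would serialize FSDP communication with computation on the
# primary stream.
import os
import warnings

if os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") == "1":
    warnings.warn(
        "CUDA_DEVICE_MAX_CONNECTIONS=1 prevents FSDP2 communication from overlapping "
        "with computation; unset it when using --use-torch-fsdp2."
    )
```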
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from packaging.version import Version
from .distributed_data_parallel import DistributedDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .finalize_model_grads import finalize_model_grads
from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from contextlib import contextmanager
import torch
from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig
class _BaseDataParallel(MegatronModule):
"""A template class for DistributedDataParallel implementations."""
def __init__(self, config: TransformerConfig, module: torch.nn.Module):
super().__init__(config=config)
self.module = module
def forward(self, *inputs, **kwargs):
"""
Calls the wrapped module's forward() method.
"""
return self.module(*inputs, **kwargs)
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
try:
yield
finally:
pass
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
pass
def scale_gradients(self, scaling_factor: float) -> None:
"""Scale all gradients inside the buffers by `scaling_factor`."""
pass
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
pass
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
pass
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
pass
def state_dict(self, prefix='', keep_vars=False):
"""
Returns a dictionary containing references to the whole state of the
wrapped module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names. Parameters and buffers
set to None are not included.
"""
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""
Returns wrapped module's state_dict for checkpoint saving.
"""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
self.module.load_state_dict(state_dict, strict=strict)
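The `no_sync()` context manager above is the hook that DDP implementations use to suppress gradient synchronization on all but the last microbatch. A minimal usage sketch (`ddp_model` and `microbatches` are illustrative placeholders, not names from this file):
```
# Hypothetical gradient-accumulation loop: gradient synchronization is suppressed
# for every microbatch except the last one.
for i, batch in enumerate(microbatches):
    if i + 1 < len(microbatches):
        with ddp_model.no_sync():              # no grad sync for this microbatch
            ddp_model(batch).sum().backward()
    else:
        ddp_model(batch).sum().backward()      # last microbatch triggers grad sync
```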
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from contextlib import contextmanager
import torch
from .. import parallel_state
from ..config_logger import has_config_logger_enabled, log_config_to_disk
from ..transformer.cuda_graphs import is_graph_capturing
from ..transformer.transformer_config import TransformerConfig
from ..utils import is_float8tensor, log_single_rank
from .data_parallel_base import _BaseDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets
logger = logging.getLogger(__name__)
class DistributedDataParallel(_BaseDataParallel):
"""
DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping
communication with backprop computation by breaking up full model's gradients into smaller
buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class
also provides the option to do the gradient accumulation in a type other than the param type
(e.g., fp32 for a bf16 model).
Args:
config: Transformer config object.
ddp_config: DistributedDataParallel config object.
module: Underlying model.
disable_bucketing: If true, force assign all parameters to a single bucket. If false,
use standard bucketing policy: assign parameters to smaller buckets and all-reduce
per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
"""
def __init__(
self,
config: TransformerConfig,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
disable_bucketing: bool = False,
):
super().__init__(config=config, module=module)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.module = module
# If bucket_size is not provided as an input, use sane default.
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
# latency-bound.
if ddp_config.bucket_size is None:
ddp_config.bucket_size = max(
40000000, 1000000 * parallel_state.get_data_parallel_world_size()
)
# Set bucket_size to infinity if overlap_grad_reduce is False.
if not ddp_config.overlap_grad_reduce:
ddp_config.bucket_size = None
self.ddp_config = ddp_config
log_single_rank(
logger,
logging.INFO,
f'Setting up DistributedDataParallel with config {self.ddp_config}',
)
# Turn off bucketing if we are on a pipeline stage that is not the first (since
# data-parallel communication on these stages is not on the critical path), or if
# disable_bucketing is True (e.g., we might not want to break up model parameters
# into buckets for model chunks after the first in the interleaved schedule).
self.bucket_size = self.ddp_config.bucket_size
if parallel_state.get_pipeline_model_parallel_rank() > 0:
self.bucket_size = None
if disable_bucketing:
self.bucket_size = None
self.param_to_bucket_group = {}
# Group parameters by their gradient type.
param_to_name = {}
dense_params = []
expert_parallel_params = []
self.params_with_grad = []
for name, param in self.module.named_parameters():
if not param.requires_grad:
continue
# Track params with grad to enable direct setting
# of param.grad_added_to_main_grad
self.params_with_grad.append(param)
param.grad_added_to_main_grad = False
param_to_name[param] = name
if getattr(param, 'allreduce', True):
dense_params.append(param)
else:
expert_parallel_params.append(param)
def _allocate_buffers_for_parameters(
input_params, data_parallel_group, gradient_scaling_factor
):
param_and_grad_dtype_to_params = {}
param_and_grad_dtype_to_offsets = {}
param_and_grad_dtype_to_indices = {}
# Group parameters by their gradient type.
for param in input_params:
assert param.requires_grad
param_dtype = param.dtype
if is_float8tensor(param):
# Currently TE's Float8Tensor is a wrapper of torch.Tensor. It has a "fake"
# dtype (usually a higher precision dtype such as bfloat16), but its actual
# data is stored in the form of a torch uint8 tensor within the Float8Tensor's
# ".data" attribute. Therefore, when creating the param buffer for fp8 params,
# it is necessary to use torch.uint8, not the "fake" dtype got from
# "param.dtype".
param_dtype = torch.uint8
grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype
params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
params.append(param)
param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params
# Get the index of each param among the params with the same dtype. If a param is fp8,
# use its "fake" high-precision dtype to determine which params share its dtype.
# For example:
# Case 1:
# params = [p1(bf16), p2(bf16), p3(bf16), p4(bf16)]
# param_and_grad_dtype_to_indices = {
# (torch.bfloat16, torch.float32): [0, 1, 2, 3],
# }
# Case 2:
# params = [p1(bf16), p2(fp8), p3(fp8), p4(bf16)]
# param_and_grad_dtype_to_indices = {
# (torch.bfloat16, torch.float32): [0, 3],
# (torch.uint8, torch.float32): [1, 2],
# }
# We need these indices to load a non-native-fp8 checkpoint in native-fp8 mode.
offset = param_and_grad_dtype_to_offsets.get((param.dtype, grad_dtype), 0)
param_and_grad_dtype_to_offsets[(param.dtype, grad_dtype)] = offset + 1
indices = param_and_grad_dtype_to_indices.get((param_dtype, grad_dtype), [])
indices.append(offset)
param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] = indices
if not config.calculate_per_token_loss:
target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size(
with_context_parallel=True
)
if self.ddp_config.average_in_collective:
if self.ddp_config.num_distributed_optimizer_instances == 1:
# Collective is averaging gradients in collective with data_parallel_group.
assert (
gradient_scaling_factor
/ torch.distributed.get_world_size(group=data_parallel_group)
== target_gradient_scaling_factor
)
else:
# For non-expert parameters, gradient_scaling_factor is 1.
# For expert parameters, gradient_scaling_factor is 1/ep_size.
assert (gradient_scaling_factor == 1) or (
gradient_scaling_factor
== (1.0 / parallel_state.get_expert_model_parallel_world_size())
)
else:
assert gradient_scaling_factor == target_gradient_scaling_factor
# Allocate the grad buffers and map the grads.
buffers = []
for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
buffers.append(
_ParamAndGradBuffer(
self.ddp_config,
param_dtype,
grad_dtype,
params,
data_parallel_group,
self.bucket_size,
param_to_name,
gradient_scaling_factor,
param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)],
)
)
# In some scenarios, we want to put buckets from different buffers into a group so that
# their communication can be aggregated. For example, when there are both fp8 buffers
# and bf16 buffers in the model and vpp is enabled, each model chunk will have an fp8
# bucket and a bf16 bucket, which doubles the number of communication kernels, and
# because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back
# communications will prevent the overlap of the communication kernels with computation
# kernels.
# If bucketing is explicitly disabled, then put all buckets in a buffer into a single
# bucket group.
bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing)
if self.ddp_config.num_distributed_optimizer_instances > 1:
assert (
self.ddp_config.use_distributed_optimizer
), 'Partial DistOpt cannot be used without DistOpt'
communication_stream = torch.cuda.Stream(device=torch.cuda.current_device())
for bucket_group in bucket_groups:
bucket_group.inter_distributed_optimizer_instance_group = (
parallel_state.get_inter_partial_data_parallel_group()
)
bucket_group.communication_stream = communication_stream
# Set `next_param_gather_bucket_group` for different bucket groups by iterating through
# buckets in reverse order (since all-gathers happen in reverse order of buckets).
if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
num_bucket_groups = len(bucket_groups)
for i in range(1, num_bucket_groups):
bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
bucket_groups[num_bucket_groups - i - 1]
)
# Create map from param to bucket group, used in pre_hook.
for bucket_group in bucket_groups:
for bucket in bucket_group.buckets:
for param in bucket.params_list:
self.param_to_bucket_group[param] = bucket_group
return buffers, bucket_groups
if config.calculate_per_token_loss:
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = 1.0
else:
if self.ddp_config.average_in_collective:
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = (
1.0 / parallel_state.get_expert_model_parallel_world_size()
)
else:
data_parallel_world_size = parallel_state.get_data_parallel_world_size(
with_context_parallel=True
)
gradient_scaling_factor = 1.0 / data_parallel_world_size
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
# Allocate the param+grad buffers for dense params' grads.
self.buffers, self.bucket_groups = _allocate_buffers_for_parameters(
dense_params,
parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=True
),
gradient_scaling_factor=gradient_scaling_factor,
)
# Allocate separate param+grad buffers for expert parallel params' grads.
self.expert_parallel_buffers, self.expert_parallel_bucket_groups = (
_allocate_buffers_for_parameters(
expert_parallel_params,
parallel_state.get_expert_data_parallel_group(),
gradient_scaling_factor=expert_gradient_scaling_factor,
)
)
# Delete references to weight_tensor if they exist since we don't want two parameter copies
# if we re-mapped parameters (which happens when we use the distributed optimizer).
# This is a temporary workaround around a TE bug that is fixed with
# https://github.com/NVIDIA/TransformerEngine/pull/719.
if self.ddp_config.use_distributed_optimizer:
@torch.no_grad()
def unmap_weight_tensor(m):
if hasattr(m, 'weight_tensor'):
m.weight_tensor = None
self.module.apply(unmap_weight_tensor)
# Register backward hook.
# Accumulation function for the gradients need to be stored so they
# don't go out of scope.
self.grad_accs = []
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_backward_post_hook(param))
self.grad_accs.append(grad_acc)
self.use_forward_hook = (
self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather
)
self.remove_forward_pre_hook_handles = {}
if self.use_forward_hook:
self.enable_forward_pre_hook()
self.overlap_param_gather_with_optimizer_step = False
def enable_forward_pre_hook(self):
"""
Enable forward pre-hooks needed for param all-gather overlap with forward compute.
"""
assert self.use_forward_hook
assert len(self.remove_forward_pre_hook_handles) == 0
# Register forward pre-hook for all sub-modules.
for module in self.module.modules():
self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook(
self._make_forward_pre_hook()
)
def disable_forward_pre_hook(self, param_sync: bool = True):
"""
Disable forward pre-hooks needed for param all-gather overlap with forward compute.
Skip synchronous param all-gather if `param_sync` is False.
"""
assert self.use_forward_hook
# De-register forward pre-hook for all sub-modules.
for module in self.module.modules():
assert self.remove_forward_pre_hook_handles[module] is not None
self.remove_forward_pre_hook_handles[module].remove()
del self.remove_forward_pre_hook_handles[module]
assert len(self.remove_forward_pre_hook_handles) == 0
# Force synchronize parameters.
if param_sync:
self.start_param_sync(force_sync=True)
def _make_forward_pre_hook(self):
"""
Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
when a module uses a parameter in a bucket with a still incomplete all-gather).
"""
def hook(module, *unused):
assert (
self.use_forward_hook
), "Should use pre-hook only when overlap_param_gather is True"
if is_graph_capturing():
return
# Make sure all parameters in this module have been all-gathered as necessary.
for param in module.parameters(recurse=False):
# Skip parameters without an associated buffer (such parameters have a
# .requires_grad field equal to False).
if param not in self.param_to_bucket_group:
continue
assert param.requires_grad
# If aligning param all-gather across pipeline stages, all-gather is dispatched
# by start_param_sync calls in core/pipeline_parallelism/schedules.py.
# If overlapping param all-gather with optimizer step, then all-gather has
# already been dispatched in optimizer step.
skip_next_bucket_dispatch = (
self.ddp_config.align_param_gather
or self.overlap_param_gather_with_optimizer_step
)
self.param_to_bucket_group[param].finish_param_sync(
skip_next_bucket_dispatch=skip_next_bucket_dispatch
)
return hook
def _make_backward_post_hook(self, param: torch.nn.Parameter):
"""
Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
ready (i.e., when all grads in a bucket have been computed in all microbatches
in a batch).
"""
def hook(*unused):
if is_graph_capturing():
return
if param in self.param_to_bucket_group:
assert param.requires_grad
if self.ddp_config.overlap_grad_reduce:
assert (
param.grad is not None
), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
param.main_grad.add_(param.grad.data)
param.grad = None
if self.ddp_config.overlap_grad_reduce:
self.param_to_bucket_group[param].register_grad_ready(param)
return hook
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.is_last_microbatch = False
try:
yield
finally:
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.is_last_microbatch = True
def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
"""
Initiates param sync (all-gather) communication operations for all model parameters.
By default, when overlap_param_gather is set to True, dispatches asynchronous communication
calls; when overlap_param_gather is set to False, calls synchronous communication
ops. Can override this default behavior using flags below.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings.
force_dispatch (bool, optional): force dispatch regardless of other settings.
"""
if not force_sync:
# If overlapping param AG with optimizer step, AG should not be dispatched again
# in forward_backward_step.
if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
return
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.start_param_sync(force_sync=force_sync)
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.start_grad_sync()
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.finish_grad_sync()
def scale_gradients(self, scaling_factor: float):
"""Scale all gradients inside the buffers by `scaling_factor`."""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.scale_gradients(scaling_factor)
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
for param in self.params_with_grad:
param.grad_added_to_main_grad = False
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.reset()
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.reset()
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
for param in self.module.parameters():
is_expert_parallel = not getattr(param, 'allreduce', True)
if is_expert_parallel:
data_parallel_group = parallel_state.get_expert_data_parallel_group()
else:
data_parallel_group = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=True
)
torch.distributed.broadcast(
param.data,
src=torch.distributed.get_global_rank(data_parallel_group, 0),
group=data_parallel_group,
)
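For orientation, here is a minimal wiring sketch of this wrapper (`transformer_config` and `language_model` are placeholders; actual usage goes through Megatron's model-provider and training setup):
```
# Hypothetical setup: wrap a model chunk so grads accumulate into contiguous
# main_grad buffers and are reduced per bucket, overlapping with backprop.
ddp_config = DistributedDataParallelConfig(
    overlap_grad_reduce=True,
    use_distributed_optimizer=True,
)
ddp_model = DistributedDataParallel(
    config=transformer_config,   # TransformerConfig of the wrapped model (placeholder)
    ddp_config=ddp_config,
    module=language_model,       # underlying MegatronModule (placeholder)
)
ddp_model.broadcast_params()     # sync initial parameters across DP ranks
```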
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Optional
@dataclass
class DistributedDataParallelConfig:
"""Configuration for DistributedDataParallel."""
grad_reduce_in_fp32: bool = False
"""If true, reduce grads in fp32."""
overlap_grad_reduce: bool = False
"""If true, overlap grad all-reduce / reduce-scatter with backward compute."""
overlap_param_gather: bool = False
"""If true, overlap param all-gather with forward compute."""
align_param_gather: bool = False
"""If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each
PP stage will independently launch as needed.
"""
use_distributed_optimizer: bool = False
"""If true, issue reduce-scatter collectives to aggregate gradients and clean up
originally allocated model parameters, otherwise issue all-reduce collectives.
"""
num_distributed_optimizer_instances: int = 1
"""Sets the factor by which the DP domain is sharded to have the partial DistOpt
enabled. Defaults to 1, which means DistOpt is across entire DP domain.
"""
check_for_nan_in_grad: bool = False
""" If true, check for NaNs in gradients _before_ communication collective."""
bucket_size: Optional[int] = None
"""Maximum number of parameters in each bucket. If unspecified, MCore uses a default
value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
buckets to ensure collectives do not become latency-bound)."""
average_in_collective: bool = False
"""If true, compute average in collective directly, as opposed to dividing by the
dp_size first and then computing sum in the collective."""
fp8_param_gather: bool = False
"""If true, keep the compute param in fp8 (do not use any other intermediate dtype) and
perform the param all-gather in fp8."""
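To illustrate `average_in_collective` with plain arithmetic (a standalone example, not code from this repo): pre-scaling each rank's gradient by `1/dp_size` and then summing in the collective gives the same result as reducing the unscaled gradients with `ReduceOp.AVG`:
```
dp_size = 4
per_rank_grads = [1.0, 2.0, 3.0, 4.0]

# Option A (average_in_collective=False): scale by 1/dp_size first, then SUM.
pre_scaled_sum = sum(g / dp_size for g in per_rank_grads)
# Option B (average_in_collective=True): average directly in the collective.
collective_avg = sum(per_rank_grads) / dp_size

assert pre_scaled_sum == collective_avg == 2.5
```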
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, Optional, Union
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
try:
from torch.distributed._tensor import DTensor, distribute_tensor
HAVE_DTENSOR = True
except ImportError:
HAVE_DTENSOR = False
from .. import parallel_state
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config
def _unshard_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor:
"""
Unshards the input tensor if it is a DTensor and otherwise returns the
tensor unmodified.
Args:
tensor (Union[torch.Tensor, DTensor]): The tensor to potentially unshard.
Returns:
An unsharded version of the input tensor if it is a DTensor, or the
input tensor unmodified if it is not a DTensor.
"""
if HAVE_DTENSOR and isinstance(tensor, DTensor):
unsharded_tensor = tensor.full_tensor()
for k, v in vars(tensor).items():
setattr(unsharded_tensor, k, v)
return unsharded_tensor
return tensor
def _reshard_if_dtensor(
tensor_to_shard: torch.Tensor, reference_tensor: Union[torch.Tensor, "DTensor"]
) -> Union[torch.Tensor, "DTensor"]:
"""
Reshards the input tensor to match the sharding configuration of the
reference tensor if the reference tensor is a DTensor. Otherwise, returns
the reference tensor unmodified.
Args:
tensor_to_shard (torch.Tensor): The tensor to be potentially sharded.
reference_tensor (Union[torch.Tensor, DTensor]): The reference tensor
for the sharding configuration.
Returns:
Union[torch.Tensor, DTensor]: The sharded tensor matching the reference tensor's
configuration, or the reference tensor itself if it is not a DTensor.
"""
if HAVE_DTENSOR and isinstance(reference_tensor, DTensor):
sharded_tensor = distribute_tensor(
tensor_to_shard,
device_mesh=reference_tensor.device_mesh,
placements=reference_tensor.placements,
)
for k, v in vars(reference_tensor).items():
setattr(sharded_tensor, k, v)
return sharded_tensor
return reference_tensor
def _allreduce_conditional_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce conditional embedding grads.
Reduce grads across all the pp stages to ensure that parameters of the conditional embedders
(e.g., timestep embedder, FPS embedder, label embedder) stay in sync.
This applies to models with replicated embedders on each PP / VPP rank, such as diffusion models.
"""
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and getattr(
config, "has_cond_embedder", False
):
grads_dict = {}
for model_chunk in model:
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
if param.requires_grad and getattr(param, 'pipeline_parallel', False):
grad = param.main_grad
if name in grads_dict:
# Add all the virtual PP rank's gradients to
# the first local virtual PP rank.
grads_dict[name][0].add_(grad)
# Append to the end for later update after cross-rank reduce.
grads_dict[name].append(grad)
else:
grads_dict[name] = [grad]
if grads_dict:
# All-reduce the gradient on the first VPP rank.
grads = [param_grad[0] for _, param_grad in grads_dict.items()]
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, group=parallel_state.get_pipeline_model_parallel_group()
)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
# Update the gradients on other VPP ranks.
for grads in grads_dict.values():
for grad in grads[1:]:
grad.copy_(grads[0])
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce word embedding grads.
Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
sync.
"""
if (
parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce position_embeddings grad across encoder and decoder stages to ensure that position
embeddings parameters stay in sync.
"""
if (
parallel_state.is_rank_in_position_embedding_group()
and torch.distributed.get_world_size(parallel_state.get_position_embedding_group()) > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
assert hasattr(model_module, 'position_embeddings')
weight = model_module.position_embeddings.weight
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce both word and position embeddings.
"""
_allreduce_word_embedding_grads(model, config)
_allreduce_position_embedding_grads(model, config)
def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce layernorm grads (for sequence parallelism).
"""
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if parallel_state.get_tensor_model_parallel_world_size() > 1 and (
config.sequence_parallel or config.qk_layernorm
):
params = []
grads = []
for model_chunk in model:
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
if (
param.requires_grad
and getattr(param, 'sequence_parallel', False)
or 'q_layernorm' in name
or 'k_layernorm' in name
):
params.append(param)
grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad"
grad = getattr(param, grad_attr)
grad = _unshard_if_dtensor(grad)
grads.append(grad.data)
if grads:
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, group=parallel_state.get_tensor_model_parallel_group()
)
for param, buf, synced in zip(
params, grads, _unflatten_dense_tensors(coalesced, grads)
):
buf.copy_(synced)
grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad"
orig_grad = getattr(param, grad_attr)
setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))
def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
"""
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
embedding grads across first and last pipeline stages (if not tied),
scale gradients by `num_tokens`.
"""
config = get_model_config(model[0])
# All-reduce / reduce-scatter across DP replicas.
if config.timers is not None:
config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
for model_chunk in model:
model_chunk.finish_grad_sync()
if config.timers is not None:
config.timers('all-grads-sync').stop()
# All-reduce t_embedder grads (for pp & vpp of DiT).
if config.timers is not None:
config.timers('conditional-embedder-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_conditional_embedding_grads(model, config)
if config.timers is not None:
config.timers('conditional-embedder-grads-all-reduce').stop()
# All-reduce layer-norm grads (for sequence parallelism).
if config.timers is not None:
config.timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_layernorm_grads(model, config)
if config.timers is not None:
config.timers('layernorm-grads-all-reduce').stop()
# All-reduce embedding grads (for pipeline parallelism).
if config.timers is not None:
config.timers('embedding-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_embedding_grads(model, config)
if config.timers is not None:
config.timers('embedding-grads-all-reduce').stop()
# Normalize gradients for per-token loss normalization.
# If we are normalizing by the number of tokens, use it as the divisor; this number
# is the total number of non-padded tokens in the global batch.
if num_tokens is not None:
# the number of tokens is only present on the last stage, so broadcast it
# to the other ranks in the pipeline parallel group.
last_rank = parallel_state.get_pipeline_model_parallel_last_rank()
pp_group = parallel_state.get_pipeline_model_parallel_group()
if not isinstance(last_rank, list):
assert not isinstance(last_rank, list)
last_rank = [last_rank]
assert not isinstance(pp_group, list)
pp_group = [pp_group]
# need to do a broadcast for every pp group, even though num_tokens should be the same.
num_tokens_list = []
for lr, group in zip(last_rank, pp_group):
torch.distributed.broadcast(num_tokens, src=lr, group=group)
num_tokens_list.append(torch.clone(num_tokens))
assert all(x.item() == num_tokens_list[0] for x in num_tokens_list)
# all-reduce across DP ranks.
torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
for model_chunk in model:
if num_tokens > 0:
scaling = 1.0 / num_tokens
model_chunk.scale_gradients(scaling)
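A hypothetical call-site sketch for `finalize_model_grads` (the token count and model list are illustrative; in Megatron the training loop invokes this after the last backward pass of the global batch):
```
# Illustrative only: finish all gradient reductions for the model chunks and
# normalize gradients by the total number of non-padded tokens in the global batch.
num_tokens = torch.tensor(131072.0, device="cuda")
finalize_model_grads(model=[ddp_model], num_tokens=num_tokens)
```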
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
import math
from contextlib import nullcontext
from enum import Enum
from typing import Dict, List, Optional
import torch
from torch.distributed import _coalescing_manager
from megatron.core.rerun_state_machine import get_rerun_state_machine
from ..utils import is_float8tensor, is_torch_min_version, log_on_each_pipeline_stage
from .distributed_data_parallel_config import DistributedDataParallelConfig
logger = logging.getLogger(__name__)
if is_torch_min_version("1.13.0"):
dist_all_gather_func = torch.distributed.all_gather_into_tensor
dist_reduce_scatter_func = torch.distributed.reduce_scatter_tensor
else:
dist_all_gather_func = torch.distributed._all_gather_base
dist_reduce_scatter_func = torch.distributed._reduce_scatter_base
class BufferType(Enum):
"""
Enumeration for buffer type.
"""
PARAM = 1
GRAD = 2
def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
"""
Shard buffer into data_parallel_world_size chunks of equal size.
"""
assert buffer.numel() % data_parallel_world_size == 0
shard_size = buffer.numel() // data_parallel_world_size
sharded_buffer = [
buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size)
]
return sharded_buffer
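# Example (illustrative comment, not in the original source): a 12-element buffer
# sharded across data_parallel_world_size=4 yields four contiguous 3-element views:
#   shard_buffer(torch.arange(12), 4)
#   -> [tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9, 10, 11])]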
class _ParamAndGradBucket:
"""
Bucket to keep track of a subset of the model's parameters and gradients.
Args:
params: List of parameters whose gradients are collated in this bucket.
param_data: View in _ParamAndGradBuffer.param_data that this bucket is responsible for.
grad_data: View in _ParamAndGradBuffer.grad_data that this bucket is responsible for.
offset: Offset of this bucket's view in the larger _ParamAndGradBuffer.
numel_unpadded: Number of unpadded elements in bucket.
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
bucket_id: Index of bucket in buffer.
"""
def __init__(
self,
params: List[torch.nn.Parameter],
param_data: Optional[torch.Tensor],
grad_data: torch.Tensor,
offset: int,
numel_unpadded: int,
gradient_scaling_factor: float,
bucket_id: int,
):
self.params_list = params
self.params = set(params)
# Make sure there are no duplicate params.
assert len(self.params_list) == len(self.params)
self.param_data = param_data
self.grad_data = grad_data
# The distributed optimizer needs to keep track of this bucket's offset
# within the full grad_buffer.
self.offset = offset
self.numel_unpadded = numel_unpadded
self.gradient_scaling_factor = gradient_scaling_factor
self.bucket_id = bucket_id
class _ParamAndGradBucketGroup:
"""
Put multiple buckets into a group so that their communications can be aggregated together.
Provides functionality to register when params in the bucket group have grads ready to be
synced; an asynchronous communication call is automatically launched when _all_ params in
the bucket group have grads ready.
Args:
buckets: A list of buckets.
ddp_config: DistributedDataParallel config object.
collective_group: intra_distributed_optimizer_instance_group if using distributed
optimizer, data_parallel_group if not.
collective_group_size: World size of the corresponding collective group.
"""
def __init__(
self,
buckets: List[_ParamAndGradBucket],
ddp_config: DistributedDataParallelConfig,
collective_group: torch.distributed.ProcessGroup,
collective_group_size: int,
):
self.buckets = buckets
self.ddp_config = ddp_config
if self.ddp_config.use_distributed_optimizer:
self.intra_distributed_optimizer_instance_group = collective_group
self.intra_distributed_optimizer_instance_size = collective_group_size
self.intra_distributed_optimizer_instance_rank = torch.distributed.get_rank(
group=collective_group
)
else:
self.data_parallel_group = collective_group
# State for bookkeeping: params is the set of parameters this bucket group is
# responsible for, params_with_grad is the set of parameters with grads
# available. When overlap_grad_reduce is True, communication (all-reduce
# or reduce-scatter) is issued when params_with_grad equals params.
self.param_to_bucket = {}
self.params = set()
for bucket in self.buckets:
for param in bucket.params_list:
self.param_to_bucket[param] = bucket
self.params.add(param)
self.next_param_gather_bucket_group = None
if self.ddp_config.num_distributed_optimizer_instances > 1:
self.inter_distributed_optimizer_instance_group = None
self.communication_stream = None
self.reset()
self.param_gather_handle = None
self.param_gather_dispatched = False
self.grad_reduce_handle = None
def reset(self):
"""
Reset metadata in bucket group in preparation for the next iteration of training.
"""
self.params_with_grad = set()
self.is_last_microbatch = True
def check_for_nan_in_grad(self):
"""
Make sure norm of grads in bucket are not NaN prior to data-parallel
all-reduce / reduce-scatter.
"""
rerun_state_machine = get_rerun_state_machine()
for i in range(len(self.buckets)):
rerun_state_machine.validate_result(
result=self.buckets[i].grad_data.norm(p=2),
rejection_func=torch.isnan,
message=f"found NaN in local grad norm for bucket #{i} "
f"in backward pass before data-parallel communication collective",
tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward
fatal=True,
)
def start_param_sync(self, force_sync: bool = False):
"""
Initiates all necessary param all-gathers for this bucket.
When ddp_config.overlap_param_gather is set to True, dispatches an asynchronous
communication call (unless force_sync is True). When ddp_config.overlap_param_gather
is set to False, makes synchronous call.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings if true.
"""
assert self.ddp_config.use_distributed_optimizer
if force_sync:
if self.param_gather_handle is not None:
self.param_gather_handle.wait()
self.param_gather_handle = None
return
else:
assert self.param_gather_handle is None
async_op = self.ddp_config.overlap_param_gather and not force_sync
# Coalesce communication kernels across buckets in the bucket group.
with _coalescing_manager(
self.intra_distributed_optimizer_instance_group, async_ops=async_op
) as cm:
for bucket in self.buckets:
local_data_view = shard_buffer(
bucket.param_data, self.intra_distributed_optimizer_instance_size
)[self.intra_distributed_optimizer_instance_rank]
dist_all_gather_func(
bucket.param_data,
local_data_view,
group=self.intra_distributed_optimizer_instance_group,
async_op=async_op,
)
if async_op:
self.param_gather_handle = cm
else:
# When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
# `cm` is not None, which is different from when `_coalescing_manager` is not used in
# which case the torch.distributed._all_gather_base() will return None. In order to
# maintain consistency with prior code, we need to manually set communication handle to
# None.
self.param_gather_handle = None
self.param_gather_dispatched = True
def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
"""
Finishes param sync communication operation for this bucket. Dispatches
next bucket's param sync if available, unless skip_next_bucket_dispatch
is True.
When ddp_config.overlap_param_gather is set to True, waits for asynchronous
communication call to complete (and dispatches one if one is not already
outstanding). Throws assertion error if ddp_config.overlap_param_gather is set to
False.
Args:
skip_next_bucket_dispatch (bool, optional): if true, do not dispatch the next
bucket's communication, even if it is available.
"""
assert self.ddp_config.use_distributed_optimizer
assert self.ddp_config.overlap_param_gather
# If current bucket's param AG has not been dispatched, dispatch it now (e.g., first
# AG bucket in first model chunk if ddp_config.align_param_gather is False).
if not self.param_gather_dispatched:
self.start_param_sync()
if self.param_gather_handle is not None:
self.param_gather_handle.wait()
self.param_gather_handle = None
# Dispatch next bucket's asynchronous param AG.
if self.next_param_gather_bucket_group is not None and not skip_next_bucket_dispatch:
self.next_param_gather_bucket_group.start_param_sync()
def start_grad_sync(self):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous
communication call. When ddp_config.overlap_grad_reduce is set to False, makes
synchronous call.
"""
assert (
self.grad_reduce_handle is None
), 'Should not have multiple communication calls outstanding at once'
if self.ddp_config.check_for_nan_in_grad:
self.check_for_nan_in_grad()
# gradient_scaling_factor already takes into account whether we are computing
# an average or sum in the data-parallel collective.
for bucket in self.buckets:
if bucket.gradient_scaling_factor != 1.0:
bucket.grad_data *= bucket.gradient_scaling_factor
# Decide reduce_op.
reduce_op = torch.distributed.ReduceOp.SUM
if self.ddp_config.average_in_collective:
reduce_op = torch.distributed.ReduceOp.AVG
# We use the following stream synchronization for the gradient reduction
# within and across DistOpt instances.
# Compute Stream: -------------Gradient compute-------------------
# Comm. Stream: ------(wait for NCCL)-----(wait for NCCL)-------
# NCCL Stream: -------RS------ -------AR------
# Use async communications only when overlap_grad_reduce is True.
async_op = (
self.ddp_config.overlap_grad_reduce
and self.ddp_config.num_distributed_optimizer_instances == 1
)
if (
self.ddp_config.num_distributed_optimizer_instances > 1
and self.ddp_config.overlap_grad_reduce
):
# Assign a communication stream if we have multiple DistOpt instances and we
# need to overlap communication.
stream_context = torch.cuda.stream(self.communication_stream)
# The RS/AR communication stream needs to wait for the default stream
# to complete its gradient computation before launching the next
# gradient reduction collective.
self.communication_stream.wait_stream(torch.cuda.default_stream())
else:
stream_context = nullcontext()
if self.ddp_config.use_distributed_optimizer:
communication_group = self.intra_distributed_optimizer_instance_group
else:
communication_group = self.data_parallel_group
# Coalesce communication kernels across buckets in the bucket group.
with stream_context, _coalescing_manager(communication_group, async_ops=async_op) as cm:
for bucket in self.buckets:
if self.ddp_config.use_distributed_optimizer:
local_data_view = shard_buffer(
bucket.grad_data, self.intra_distributed_optimizer_instance_size
)[self.intra_distributed_optimizer_instance_rank]
dist_reduce_scatter_func(
local_data_view,
bucket.grad_data,
op=reduce_op,
group=communication_group,
async_op=async_op,
)
else:
torch.distributed.all_reduce(
bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op
)
# With multiple DistOpt instances, we need to all-reduce across instances.
if (
self.ddp_config.use_distributed_optimizer
and self.ddp_config.num_distributed_optimizer_instances > 1
):
# Create a new coalescing manager for the inter-instance all-reduce.
with stream_context, _coalescing_manager(
self.inter_distributed_optimizer_instance_group, async_ops=async_op
) as cm:
for bucket in self.buckets:
local_data_view = shard_buffer(
bucket.grad_data, self.intra_distributed_optimizer_instance_size
)[self.intra_distributed_optimizer_instance_rank]
torch.distributed.all_reduce(
local_data_view,
op=reduce_op,
group=self.inter_distributed_optimizer_instance_group,
async_op=async_op,
)
if async_op:
self.grad_reduce_handle = cm
else:
# When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
# `cm` is not None, which is different from when `_coalescing_manager` is not used in
# which case the torch.distributed._reduce_scatter_base() will return None. In order to
# maintain consistency with prior code, we need to manually set communication handle to
# None.
self.grad_reduce_handle = None
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous
communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
makes synchronous call.
"""
self.param_gather_dispatched = False
# If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
if not self.ddp_config.overlap_grad_reduce:
self.start_grad_sync()
return
# When using multiple DistOpt instances, we don't need to sync here as we launch
# communications on a separate communication stream.
if self.ddp_config.num_distributed_optimizer_instances > 1:
torch.cuda.default_stream().wait_stream(self.communication_stream)
return
assert self.grad_reduce_handle is not None, (
f'Communication call has not been issued for this bucket '
f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
)
self.grad_reduce_handle.wait()
self.grad_reduce_handle = None
def register_grad_ready(self, param: torch.nn.Parameter):
"""
Registers grads for the passed-in param to be "ready" for grad sync.
When the number of microbatches is greater than 1, we only want to register
grads as ready when processing the last microbatch and ddp_config.overlap_grad_reduce
is True.
"""
assert (
self.ddp_config.overlap_grad_reduce
), 'register_grad_ready() should only be called when overlap_grad_reduce is True'
if self.is_last_microbatch:
assert param in self.param_to_bucket, 'Param is not in the bucket group'
assert param not in self.params_with_grad, 'Cannot set grad twice'
self.params_with_grad.add(param)
# If all params in bucket group have grads available, issue communication call.
if len(self.params_with_grad) == len(self.params):
self.start_grad_sync()
class _ParamAndGradBuffer:
"""
Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into
buckets with roughly `bucket_size` parameters each.
Args:
ddp_config: DistributedDataParallel config object.
param_dtype: Type of param tensor.
grad_dtype: Type of grad tensor.
params: List of parameters whose parameters and gradients are collated in the underlying
tensor.
data_parallel_group: Data-parallel process group.
bucket_size: The rough size of each bucket in terms of number of parameters.
param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes).
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
param_indices: The index of each param among the params with the same dtype. If a param
is fp8, use its "fake" high-precision dtype to determine which params share its dtype.
These indices are needed when loading a non-native-fp8 checkpoint in native-fp8 mode.
"""
def __init__(
self,
ddp_config: DistributedDataParallelConfig,
param_dtype: torch.dtype,
grad_dtype: torch.dtype,
params: List[torch.nn.Parameter],
data_parallel_group: torch.distributed.ProcessGroup,
bucket_size: int,
param_to_name: Dict[torch.nn.Parameter, str],
gradient_scaling_factor: float,
param_indices: List[int],
):
self.ddp_config = ddp_config
self.params = params
self.param_indices = param_indices
# Check that params are unique.
unique_params = set()
for param in params:
assert param not in unique_params
unique_params.add(param)
del unique_params
# Store attributes that will be needed later.
self.param_dtype = param_dtype
self.grad_dtype = grad_dtype
self.data_parallel_group = data_parallel_group
self.data_parallel_world_size = torch.distributed.get_world_size(
group=self.data_parallel_group
)
self.gradient_scaling_factor = gradient_scaling_factor
# Data structures to store underlying buckets and relevant indexing data.
self.buckets = []
self.param_to_bucket = {} # Param -> bucket mapping.
self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer).
def _pad(number_to_be_padded: int, divisor: int) -> int:
return int(math.ceil(number_to_be_padded / divisor) * divisor)
def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int:
"""
Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
"""
if self.ddp_config.use_distributed_optimizer:
# Workaround for TE bug causing cuBLAS to pick an incompatible algorithm.
# This also helps cuBLAS pick more efficient algorithms for GEMMs.
# We now ensure that all buckets start at a memory address that is 256-byte
# aligned (128 values since params and grads use >= 16-bit precision).
return _pad(bucket_end_index, math.lcm(self.data_parallel_world_size, 128))
return bucket_end_index
def _pad_start_of_param_if_needed(param_start_index: int) -> int:
"""
Pads start index of param if using distributed optimizer (to ensure "good" alignment).
"""
if self.ddp_config.use_distributed_optimizer:
# Ensure that params start at 128-byte aligned addresses (64 values
# since params are >= 16-bit precision).
return _pad(param_start_index, 64)
return param_start_index
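# Worked example (illustrative comment, not in the original source): with
# use_distributed_optimizer=True and data_parallel_world_size=8, bucket ends are padded
# to multiples of lcm(8, 128) = 128 elements, so a bucket ending at element 1000 is
# padded to _pad(1000, 128) = 1024; param start indices are padded to multiples of 64
# (128-byte alignment for >= 16-bit dtypes).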
# First, figure out how many elements should be in the underlying buffer storage.
# Note that if we need to split the buffer into smaller buckets, each of these
# might need to be padded as well (if using the distributed optimizer).
param_start_index = 0
bucket_start_index = param_start_index
bucket_params = set()
self.bucket_indices = []
per_bucket_numel_unpadded = []
bucket_id = 0
def _update_bucket_metadata(param_end_index: int) -> int:
"""
Record metadata for the bucket starting at bucket_start_index and ending with the
passed-in param_end_index. Returns the bucket's end_index.
"""
nonlocal bucket_start_index, bucket_params, bucket_id
per_bucket_numel_unpadded.append(param_end_index - bucket_start_index)
bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)
# Record metadata of new bucket.
self.bucket_indices.append((bucket_start_index, bucket_end_index))
bucket_start_index = bucket_end_index
# Prepare for next bucket.
bucket_params = set()
bucket_id += 1
# Return the potentially padded bucket_end_index.
return bucket_end_index
def _does_param_require_new_bucket(param):
"""
Split shared embedding parameters into separate bucket if using distributed
optimizer that makes use of reduce-scatters instead of all-reduces.
This ensures that the first and last pipeline stage partition optimizer state
for the shared embedding parameters the same way across DP replicas, allowing
the DP reduce-scatter to be before the embedding all-reduce.
"""
return (
getattr(param, "shared_embedding", False)
and self.ddp_config.use_distributed_optimizer
)
for param in params[::-1]:
# Iterate through parameters in reverse order to roughly follow backprop order.
this_numel = param.data.nelement()
param_start_index = _pad_start_of_param_if_needed(param_start_index)
# Create bucket with collected parameters if current param needs its own bucket.
if _does_param_require_new_bucket(param):
# We are creating a bucket for the already accumulated parameters, whose params
# end at the current param_start_index.
if self.ddp_config.use_distributed_optimizer:
# Make sure new bucket is appropriately padded.
if param_start_index % self.data_parallel_world_size != 0:
param_start_index = _pad_end_of_bucket_if_needed(param_start_index)
if len(bucket_params) > 0:
bucket_end_index = _update_bucket_metadata(param_start_index)
param_end_index = param_start_index + this_numel
self.param_index_map[param] = (param_start_index, param_end_index, bucket_id)
bucket_params.add(param)
# If we have enough elements already or the current param is part of the shared
# embedding layer and needs a separate bucket, form a new bucket.
if (
bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size
) or _does_param_require_new_bucket(param):
bucket_end_index = _update_bucket_metadata(param_end_index)
param_start_index = bucket_end_index
else:
param_start_index = param_end_index
# Add remaining params to a new bucket.
if len(bucket_params) > 0:
bucket_end_index = _update_bucket_metadata(param_end_index)
# Next, create underlying storage for buffer (with numel elements that includes
# padding as necessary).
self.numel = bucket_end_index
self.numel_unpadded = sum(per_bucket_numel_unpadded)
assert self.numel_unpadded <= self.numel
if self.ddp_config.use_distributed_optimizer:
assert self.numel % self.data_parallel_world_size == 0
else:
assert self.numel == self.numel_unpadded
self.param_data = None
# Only re-map param tensors if using distributed optimizer.
if self.ddp_config.use_distributed_optimizer:
self.param_data = torch.zeros(
self.numel,
dtype=self.param_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
self.grad_data = torch.zeros(
self.numel,
dtype=self.grad_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
# Finally, map param.data and param.main_grad fields to buffers.
bucket_params = []
bucket_start_index = 0
cur_bucket_id = 0
for param in params[::-1]:
param_start_index, param_end_index, bucket_id = self.param_index_map[param]
# Assign param.data to appropriate segment of self.param_data.
if self.param_data is not None:
old_param_data = param.data
new_param_data = self._get(
param.data.shape, param_start_index, buffer_type=BufferType.PARAM
)
if is_float8tensor(param):
param._data = new_param_data
else:
param.data = new_param_data
assert old_param_data._base is None
# Copy tensor values (from initialization or checkpoint).
param.data.detach().copy_(old_param_data)
del old_param_data
param.main_grad = self._get(
param.data.shape, param_start_index, buffer_type=BufferType.GRAD
)
if bucket_id != cur_bucket_id:
bucket_end_index = _pad_end_of_bucket_if_needed(param_start_index)
self.buckets.append(
self._new_bucket(
bucket_params=bucket_params,
start_index=bucket_start_index,
end_index=bucket_end_index,
numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
bucket_id=cur_bucket_id,
)
)
bucket_start_index = bucket_end_index
bucket_params = []
assert cur_bucket_id + 1 == len(self.buckets)
assert bucket_id == cur_bucket_id + 1
cur_bucket_id = bucket_id
bucket_params.append(param)
# Add remaining params to a new bucket.
if len(bucket_params) > 0:
bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)
self.buckets.append(
self._new_bucket(
bucket_params=bucket_params,
start_index=bucket_start_index,
end_index=bucket_end_index,
numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
bucket_id=cur_bucket_id,
)
)
# Log buckets for all PP stages.
log_strs = []
log_strs.append(
f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
)
for index, bucket in enumerate(self.buckets):
numel = 0
for param in bucket.params:
numel += param.data.nelement()
log_strs.append(f'Params for bucket {index+1} ({numel} elements):')
for param in bucket.params:
log_strs.append(f'\t{param_to_name[param]}')
log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))
def scale_gradients(self, scaling_factor: float) -> None:
"""Scale the gradient data by `scaling_factor`."""
self.grad_data *= scaling_factor
def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor:
"""
Return a tensor with the input `shape` as a view into the 1-D data starting at
`start_index`.
"""
end_index = start_index + shape.numel()
assert end_index <= self.numel, 'Requested tensor is out of buffer range'
if buffer_type == BufferType.PARAM:
assert self.param_data is not None
buffer_tensor = self.param_data[start_index:end_index]
elif buffer_type == BufferType.GRAD:
buffer_tensor = self.grad_data[start_index:end_index]
else:
raise Exception("Illegal buffer type provided to GradBuffer._get() function")
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
def _new_bucket(
self,
bucket_params: List[torch.nn.Parameter],
start_index: int,
end_index: int,
numel_unpadded: int,
bucket_id: int,
) -> _ParamAndGradBucket:
"""
Helper function that creates a new bucket. Also updates param->bucket mapping.
"""
# Assert that indices are correctly padded (if needed), and that bucket
# position is same as originally computed.
if self.ddp_config.use_distributed_optimizer:
assert start_index % self.data_parallel_world_size == 0
assert end_index % self.data_parallel_world_size == 0
assert (start_index, end_index) == self.bucket_indices[bucket_id]
# Get appropriate view into global _ParamAndGradBuffer.
bucketed_param_data = None
if self.param_data is not None:
bucketed_param_data = self._get(
torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM
)
bucketed_grad_data = self._get(
torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD
)
bucket = _ParamAndGradBucket(
params=bucket_params,
param_data=bucketed_param_data,
grad_data=bucketed_grad_data,
offset=start_index,
numel_unpadded=numel_unpadded,
gradient_scaling_factor=self.gradient_scaling_factor,
bucket_id=bucket_id,
)
for bucket_param in bucket_params:
assert bucket_param not in self.param_to_bucket
self.param_to_bucket[bucket_param] = bucket
return bucket
def reset(self):
"""
Zero out the underlying grad_buffer.
"""
self.grad_data.zero_()
def partition_buckets(
buffers: List[_ParamAndGradBuffer], force_single_bucket_group: bool = False
) -> List[_ParamAndGradBucketGroup]:
"""
Automatically regroup the buckets of input buffers and return a list of bucket groups.
In some scenarios, we need to put buckets from different buffers into a group so that their
communication can be aggregated.
For example, when there are both fp8 weights and bf16 biases in the model and virtual
pipeline parallelism is enabled, each model chunk will have an fp8 bucket and a bf16 bucket,
which doubles the number of communication kernels, and because of the use of
CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back communications will prevent the
overlap of communication kernels with computation kernels.
The grouping strategy is:
1. If force_single_bucket_group is True, put all buckets across all buffers into a single
bucket group.
2. If force_single_bucket_group is False, when there is no fp8 buffer in the input buffers,
let each bucket group have only one bucket.
3. If force_single_bucket_group is False, when using fp8 params, merge all non-fp8 buckets
into the last fp8 bucket group.
- Since the non-fp8 parameters (typically the biases of various layers) are relatively
small, they are likely to be grouped into a single non-fp8 bucket.
- The fp8 buckets start from the end of the model, i.e., the first bucket corresponds to
the end of the model, while the last bucket corresponds to the beginning.
- If we combine the non-fp8 bucket with the first fp8 bucket, we cannot initiate the
reduce-scatter to synchronize gradients after the backward pass at the end of the model
has completed. This is because we need to wait for the non-fp8 params from the beginning
layers to obtain their gradients.
- Combining the non-fp8 bucket with the last fp8 bucket can help avoid this issue.
Args:
buffers (list): list of input buffers.
force_single_bucket_group (bool, optional): if True, put all buckets across all input
buffers into a single bucket group.
"""
if len(buffers) == 0:
return []
dtype_to_buffer_map = {}
for buffer in buffers:
dtype = buffer.param_dtype
# Make sure that the param_dtype of any two buffers is different.
assert dtype not in dtype_to_buffer_map
dtype_to_buffer_map[dtype] = buffer
# Case 1: Put all buckets into a single bucket group if force_single_bucket_group is True.
if force_single_bucket_group:
buckets = []
ddp_config = buffers[0].ddp_config
data_parallel_group = buffers[0].data_parallel_group
data_parallel_world_size = buffers[0].data_parallel_world_size
for buffer in buffers:
assert ddp_config == buffer.ddp_config
assert data_parallel_group == buffer.data_parallel_group
assert data_parallel_world_size == buffer.data_parallel_world_size
buckets.extend(buffer.buckets)
bucket_group = _ParamAndGradBucketGroup(
buckets, ddp_config, data_parallel_group, data_parallel_world_size
)
return [bucket_group]
if torch.uint8 not in dtype_to_buffer_map:
# Case 2: When there is no fp8 buffer in the input buffers, let each bucket group have
# only one bucket.
bucket_groups = []
for buffer in buffers:
for bucket in buffer.buckets:
bucket_groups.append(
_ParamAndGradBucketGroup(
[bucket],
buffer.ddp_config,
buffer.data_parallel_group,
buffer.data_parallel_world_size,
)
)
return bucket_groups
else:
# Case 3: When using fp8 params, merge all non-fp8 buckets into the last fp8 bucket group.
non_fp8_buckets = []
for buffer in buffers:
if buffer.param_dtype != torch.uint8:
for bucket in buffer.buckets:
non_fp8_buckets.append(bucket)
bucket_groups = []
fp8_buffer = dtype_to_buffer_map[torch.uint8]
for bucket in fp8_buffer.buckets:
if len(bucket_groups) == len(fp8_buffer.buckets) - 1:
# The last bucket group.
group_buckets = [bucket] + non_fp8_buckets
else:
# The first N-1 bucket groups.
group_buckets = [bucket]
bucket_groups.append(
_ParamAndGradBucketGroup(
group_buckets,
fp8_buffer.ddp_config,
fp8_buffer.data_parallel_group,
fp8_buffer.data_parallel_world_size,
)
)
return bucket_groups
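# ----------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of this module): the grouping
# strategy of partition_buckets() restated over plain Python data, so the three
# cases in its docstring are easy to follow. Buffers are modeled as
# (param_dtype, bucket_ids) pairs; fp8 buffers are identified by torch.uint8,
# as in the real code. This helper is illustration only and does not construct
# real _ParamAndGradBucketGroup objects.
def _sketch_bucket_grouping(buffers, force_single_bucket_group=False):
    """Return lists of bucket ids mirroring the grouping of partition_buckets()."""
    if force_single_bucket_group:
        # Case 1: every bucket from every buffer lands in one group.
        return [[b for _, bucket_ids in buffers for b in bucket_ids]]
    fp8_bucket_lists = [ids for dtype, ids in buffers if dtype == torch.uint8]
    if not fp8_bucket_lists:
        # Case 2: no fp8 buffer -> one bucket per group.
        return [[b] for _, bucket_ids in buffers for b in bucket_ids]
    # Case 3: merge all non-fp8 buckets into the *last* fp8 bucket group.
    fp8_buckets = fp8_bucket_lists[0]
    non_fp8_buckets = [b for dtype, ids in buffers if dtype != torch.uint8 for b in ids]
    groups = [[b] for b in fp8_buckets[:-1]]
    groups.append([fp8_buckets[-1]] + non_fp8_buckets)
    return groups

# Example: an fp8 buffer with 3 buckets plus a bf16 buffer with 1 bucket yields
# three groups, with the bf16 bucket attached to the last fp8 group:
# _sketch_bucket_grouping([(torch.uint8, ['f0', 'f1', 'f2']), (torch.bfloat16, ['b0'])])
# -> [['f0'], ['f1'], ['f2', 'b0']]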
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List
import torch
try:
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
HAVE_FSDP = True
except ImportError:
HAVE_FSDP = False
from .. import parallel_state, tensor_parallel
from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from ..transformer.transformer_config import TransformerConfig
from ..transformer.transformer_layer import TransformerLayer
from .data_parallel_base import _BaseDataParallel
class TorchFullyShardedDataParallel(_BaseDataParallel):
"""
Enables fully sharded data parallelism by wrapping the given model with
the PyTorch FSDP2 API:
https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
To utilize this class, PyTorch version >= 2.4.0 is required.
Args:
config: Transformer config object.
module: Underlying model.
sub_modules_to_wrap: List of sub_modules to shard with FSDP.
Parameters within each sub_module will be all-gathered just-in-time.
The default list includes the following submodules derived from the
GPT model architecture:
TransformerLayer (all Transformer layers)
LanguageModelEmbedding (initial embedding layer)
RotaryEmbedding (initial RoPE layer)
tensor_parallel.ColumnParallelLinear (final output layer)
"""
def __init__(
self,
config: TransformerConfig,
module: torch.nn.Module,
sub_modules_to_wrap: List[torch.nn.Module] = [
TransformerLayer,
LanguageModelEmbedding,
RotaryEmbedding,
tensor_parallel.ColumnParallelLinear,
],
**kwargs
):
assert (
HAVE_FSDP
), 'TorchFullyShardedDataParallel requires PyTorch >= 2.4.0 with FSDP 2 support.'
super().__init__(config=config, module=module)
self.data_parallel_group = parallel_state.get_data_parallel_group(
with_context_parallel=True
)
mesh = DeviceMesh.from_group(self.data_parallel_group, "cuda")
kwargs = {"mesh": mesh}
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
return custom_attrs
def restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)
# Save the custom attributes on Parameters before FSDP overwrites them.
# See https://github.com/pytorch/pytorch/issues/136929.
attrs = save_custom_attrs(self.module)
prev_module = None
for sub_module in self.module.modules():
# Wrap individual submodules to fetch parameters just-in-time rather than
# conservatively fetching all parameters at the start of each iteration.
# See https://github.com/pytorch/pytorch/issues/114299.
if any(
isinstance(sub_module, sub_module_to_wrap)
for sub_module_to_wrap in sub_modules_to_wrap
):
fully_shard(sub_module, **kwargs)
# Explicitly set the FSDP backward prefetch schedule to prevent activation
# recomputation from disrupting the automatically generated default schedule.
if config.recompute_granularity is not None:
sub_module.set_modules_to_backward_prefetch(
[prev_module] if prev_module else []
)
prev_module = sub_module
# Wrap the root module as required by the FSDP API.
# See https://github.com/pytorch/pytorch/issues/114299.
fully_shard(self.module, **kwargs)
restore_custom_attrs(self.module, attrs)
def load_state_dict(self, state_dict, strict=True):
"""
No-op because tensors are already loaded in-place by
`_load_base_checkpoint` with FSDP2."""
pass
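# ----------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of this module): the generic FSDP2
# wrapping pattern used by TorchFullyShardedDataParallel -- shard the selected
# submodules first, then the root module -- shown on a toy model. Assumes
# PyTorch >= 2.4 with FSDP2 available and an already-initialized distributed
# process group with a CUDA device; ToyBlock/ToyModel are made-up names for the
# example and do not exist in this codebase.
def _example_fsdp2_wrapping(num_blocks: int = 4, hidden: int = 1024):
    class ToyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(hidden, hidden)

        def forward(self, x):
            return self.linear(x)

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = torch.nn.ModuleList(ToyBlock() for _ in range(num_blocks))

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    model = ToyModel().cuda()
    # Shard each block so its parameters are all-gathered just-in-time during
    # forward/backward rather than kept fully materialized on every rank.
    for block in model.blocks:
        fully_shard(block)
    # Wrap the root module last, as required by the FSDP2 API.
    fully_shard(model)
    return model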
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
retro_encoder = 3
retro_decoder = 4
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from enum import Enum
DataType = Enum('DataType', ["bfloat16", "float16", "float32"])
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
@dataclass
class ExportConfig:
"""Base configuration for Megatron Core Export
These parameters control the export setting for trtllm
"""
inference_tp_size: int = 1
inference_pp_size: int = 1
use_parallel_embedding: bool = False
use_embedding_sharing: bool = False
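# Example (illustrative, not part of this module): an export configured for
# 2-way tensor parallelism with embedding sharing enabled; all other fields
# keep their defaults.
#
#     export_config = ExportConfig(inference_tp_size=2, use_embedding_sharing=True)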
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from enum import Enum
ModelType = Enum(
'ModelType', ["gpt", "gptnext", "llama", "falcon", "starcoder", "mixtral", "gemma"]
)