## How to use PyTorch FSDP2?
Add these flags to enable Torch FSDP2:
```
--use-torch-fsdp2
--no-gradient-accumulation-fusion
--ckpt-format torch_dist
```
Note that CUDA_MAX_CONNECTIONS=1 should not be set, so that FSDP communication and computation on the primary stream can be fully overlapped.
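For example, a minimal launch sketch (the script name, process count, and the remaining arguments are illustrative placeholders, not part of this change):
```
# Leave CUDA_MAX_CONNECTIONS unset so FSDP communication overlaps with compute.
unset CUDA_MAX_CONNECTIONS
torchrun --nproc-per-node 8 pretrain_gpt.py \
    --use-torch-fsdp2 \
    --no-gradient-accumulation-fusion \
    --ckpt-format torch_dist \
    ...  # other model and training arguments
```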
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
try:
from packaging.version import Version
except ImportError:
pass
from .distributed_data_parallel import DistributedDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .finalize_model_grads import finalize_model_grads
from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .fully_sharded_data_parallel import FullyShardedDataParallel
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import functools
import logging
from contextlib import contextmanager
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from megatron.core import parallel_state
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.distributed.custom_fsdp.param_and_grad_buffer import (
AllGatherPipeline,
BucketingPolicy,
GradReducePipeline,
ParamAndGradBuffer,
PrefetchOrder,
override_sharded_param_methods_with_safety_checks,
)
from megatron.core.distributed.data_parallel_base import _BaseDataParallel
from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
from megatron.core.fp8_utils import is_float8tensor
from megatron.core.process_groups_config import GradCommProcessGroups
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.utils import is_submodule, log_single_rank
logger = logging.getLogger(__name__)
class TrainingState(Enum):
"""States of a FSDP parameter group, which are coupled with
the sharding activity of parameters and gradients during training."""
# From pre-forward before post-forward, where parameters should be unsharded
FORWARD = auto()
# Prior to backward computation, where parameters should be unsharded
PRE_BACKWARD = auto()
# After backward computation, where gradients should be re-sharded
POST_BACKWARD = auto()
# Before or after module forward computation, and before pre-backward or
# after post-backward, where no un/sharding activity happens
IDLE = auto()
class FullyShardedDataParallel(_BaseDataParallel):
"""Fully Sharded Data Parallel training for MCore models.
A distributed training wrapper that shards model parameters, gradients and optimizer
states across data parallel workers. Integrates seamlessly with MCore's tensor
and expert parallelism features.
We support the following modes:
- no_shard: Traditional data parallel training without parameter sharding.
- optim: Shards optimizer states and the main weights used for mixed-precision
training; this is conceptually close to "ZeRO-1". The `optim_grads` and
`optim_grads_params` modes below also shard main weights during mixed-precision
training, although this is omitted from their descriptions.
- optim_grads: Shards gradients and optimizer states, this is conceptually close to "ZeRO-2".
- optim_grads_params: Shards parameters, gradients and optimizer states, this
is conceptually close to "ZeRO-3".
Key Features:
- Compatible with MCore's tensor, context and expert parallelism
- Automatic mixed precision training (BF16/FP8)
- Gradient accumulation and bucketing
- Optimized activation recompute with shard-aware communication: When recomputing
a whole Transformer layer, gather parameters once for both the recomputation
and backward computation
- Compatible with MCore's distributed checkpointing
Args:
config: Transformer config object.
ddp_config: FullyShardedDataParallel config object.
module: Underlying model.
fsdp_unit_modules: List of modules that should be treated as FSDP units,
i.e., the minimum releasable model units. If not provided, defaults to
[TransformerLayer] when the sharding strategy is "optim_grads_params", and
to an empty list otherwise. This setting also affects the granularity of
parameter grouping for communication and triggers aggregated collective
communication in FP8 mixed-precision training.
disable_bucketing: If true, force assign all parameters to a single bucket. If false,
use standard bucketing policy: assign parameters to smaller buckets and all-reduce
per bucket.
grad_comm_pgs: Optional GradCommProcessGroups object. If not provided, the default
process groups from parallel_state will be used. If provided, the module expects
grad_comm_pgs to have a dp_cp attribute (or dp if cp=1) and an
expt_dp attribute (if using expert data parallelism).
Examples:
>>> model = GPTModel(config)
>>> model = FullyShardedDataParallel(
... config,
... model,
... ddp_config,
... fsdp_unit_modules = [TransformerLayer, LanguageModelEmbedding],
... )
"""
def __init__(
self,
config: TransformerConfig,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
fsdp_unit_modules: Optional[List[torch.nn.Module]] = None,
disable_bucketing: bool = False,
device: Optional[torch.device] = None,
grad_comm_pgs: Optional[GradCommProcessGroups] = None,
):
super().__init__(config=config, module=module)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.module = module
self.ddp_config = ddp_config
log_single_rank(
logger,
logging.INFO,
f'Setting up DistributedDataParallel with config {self.ddp_config}',
)
# Check if the module has expert parameters.
self.contains_expert_parameters = False
for _, param in self.module.named_parameters():
if not getattr(param, 'allreduce', True):
self.contains_expert_parameters = True
break
# Initialize the data parallel and expert data parallel groups.
self.inter_fsdp_group_grad_reduce = self.ddp_config.num_distributed_optimizer_instances > 1
self.inter_distopt_group = None
self.expt_dp_group = None
self.intra_expt_dp_group = None
if grad_comm_pgs is None:
self.dp_cp_group = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)
self.intra_dp_cp_group = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=True
)
self.expt_dp_group = parallel_state.get_expert_data_parallel_group()
self.intra_expt_dp_group = parallel_state.get_expert_data_parallel_group(
partial_expert_data_parallel=True
)
if self.inter_fsdp_group_grad_reduce:
self.inter_distopt_group = (
parallel_state.get_inter_distributed_optimizer_instance_group()
)
else:
cp_size = getattr(config, 'context_parallel_size', 1)
if hasattr(grad_comm_pgs, 'dp_cp'):
self.dp_cp_group = grad_comm_pgs.dp_cp
elif hasattr(grad_comm_pgs, 'dp') and cp_size == 1:
self.dp_cp_group = grad_comm_pgs.dp
else:
raise ValueError(
"Required process group missing: 'dp_cp' (or 'dp' when context_parallel_size=1)"
)
if self.contains_expert_parameters:
assert hasattr(
grad_comm_pgs, 'expt_dp'
), 'expert process group is required when using expert parameters'
self.expt_dp_group = grad_comm_pgs.expt_dp
if self.inter_fsdp_group_grad_reduce:
self.intra_expt_dp_group = self.expt_dp_group
else:
self.intra_expt_dp_group = grad_comm_pgs.intra_expt_dp
if self.inter_fsdp_group_grad_reduce:
self.inter_distopt_group = grad_comm_pgs.inter_dist_opt
self.intra_dp_cp_group = grad_comm_pgs.intra_dp_cp
else:
self.intra_dp_cp_group = self.dp_cp_group
self.bucket_size = self.ddp_config.bucket_size
if disable_bucketing:
self.bucket_size = None
self.device = device if device else torch.cuda.current_device()
self.param_to_bucket_group = {}
if fsdp_unit_modules is not None:
self.fsdp_unit_modules = fsdp_unit_modules
else:
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
self.fsdp_unit_modules = [TransformerLayer]
else:
self.fsdp_unit_modules = []
self.main_weights = True
# Determine if we should delay the gradient reduction.
self.is_delay_grad_reduce = self.ddp_config.data_parallel_sharding_strategy in [
"no_shard",
"optim",
]
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
assert self.ddp_config.overlap_param_gather
if not self.is_delay_grad_reduce:
assert self.ddp_config.overlap_grad_reduce
self._init_fsdp_param_and_grad_buffer()
self._register_fsdp_hooks(self.module)
# Delete references to weight_tensor if they exist since we don't want two parameter copies
# if we re-mapped parameters (which happens when we use the distributed optimizer).
# This is a temporary workaround around a TE bug that is fixed with
# https://github.com/NVIDIA/TransformerEngine/pull/719.
@torch.no_grad()
def unmap_weight_tensor(m):
if hasattr(m, 'weight_tensor'):
m.weight_tensor = None
self.module.apply(unmap_weight_tensor)
def _init_fsdp_param_and_grad_buffer(self):
if self.config.calculate_per_token_loss:
# We don't need to scale the gradients in this case.
gradient_scaling_factor = None
expert_gradient_scaling_factor = None
else:
if self.ddp_config.average_in_collective:
gradient_scaling_factor = 1.0
if self.contains_expert_parameters:
expert_gradient_scaling_factor = (
self.expt_dp_group.size() / self.dp_cp_group.size()
)
else:
expert_gradient_scaling_factor = None
else:
data_parallel_world_size = self.dp_cp_group.size()
gradient_scaling_factor = 1.0 / data_parallel_world_size
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
# Initialize the param and grad buffer.
self.data_parallel_sharding_strategy = self.ddp_config.data_parallel_sharding_strategy
self.param_to_name = {p: name for name, p in self.module.named_parameters()}
self.param_and_grad_buffer = ParamAndGradBuffer(
self.ddp_config,
self.module,
bucketing_policy=BucketingPolicy(
suggested_bucket_size=self.bucket_size,
fsdp_unit_modules=self.fsdp_unit_modules,
data_parallel_sharding_strategy=self.data_parallel_sharding_strategy,
),
data_parallel_group=self.intra_dp_cp_group,
expert_data_parallel_group=self.intra_expt_dp_group,
inter_data_parallel_group=self.inter_distopt_group,
preserve_fp32_weights=self.ddp_config.preserve_fp32_weights,
grad_reduce_in_fp32=self.ddp_config.grad_reduce_in_fp32,
gradient_scaling_factor=gradient_scaling_factor,
expert_gradient_scaling_factor=expert_gradient_scaling_factor,
device=self.device,
reset_parameters_for_meta_device_init_module=self.config.init_model_with_meta_device,
)
self.side_stream_for_buffer_copy_and_grad_accum = torch.cuda.Stream()
# Initialize the reduce-scatter pipeline.
self.grad_reduce_pipeline = GradReducePipeline(
self.param_and_grad_buffer,
rs_stream=self.side_stream_for_buffer_copy_and_grad_accum,
inter_fsdp_group_grad_reduce=self.inter_fsdp_group_grad_reduce,
)
# Initialize the all-gather pipeline.
self.all_gather_pipeline = AllGatherPipeline(self.param_and_grad_buffer)
suggested_communication_unit_size = self.ddp_config.suggested_communication_unit_size
if suggested_communication_unit_size is None:
if self.data_parallel_sharding_strategy == "optim_grads_params":
total_param_elements = 0
total_fsdp_module = 0
for module in self.module.modules():
if isinstance(module, tuple(self.fsdp_unit_modules)):
total_fsdp_module += 1
total_param_elements += sum(p.numel() for p in module.parameters())
# The suggested size is twice the average number of parameter elements per
# FSDP unit module. This ensures we process the current FSDP module and attempt
# to prefetch the next one, improving the flow of communication.
suggested_communication_unit_size = total_param_elements // total_fsdp_module * 2
elif self.bucket_size is not None:
suggested_communication_unit_size = self.bucket_size * 2
self.suggested_RS_queue_capacity = suggested_communication_unit_size
self.suggested_AG_prefetch_size = suggested_communication_unit_size
if self.data_parallel_sharding_strategy == "optim_grads_params":
override_sharded_param_methods_with_safety_checks(
self.module.parameters(), self.all_gather_pipeline
)
def _register_fsdp_hooks(self, root_module):
"""Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model.
This function sets up various hooks required for FSDP operations, including parameter
resharding/unsharding and gradient handling. The registered hooks are:
- Pre-forward hook: Unshards parameters before forward pass
- Post-forward hook: Reshards parameters after forward pass
- Pre-backward hook: Unshards parameters before backward pass
- Post-backward hook: Reshards parameters and reduces gradients after backward pass
Args:
root_module: The PyTorch module to register FSDP hooks on
Note:
These hooks are essential for FSDP's memory efficiency as they manage:
1. Dynamic parameter sharding/unsharding to reduce memory footprint
2. Proper gradient synchronization across distributed processes
3. Gradient accumulation for large batch training
Returns:
None
"""
# Initialize module training state.
for m in root_module.modules():
setattr(m, "_training_state", TrainingState.IDLE)
self.forward_pre_hooks = {}
self.forward_hooks = {}
self.backward_pre_hooks = {}
"""
An FSDP unit is a module designed to manage the lifecycle of model parameters
in Fully Sharded Data Parallel (FSDP) training. It ensures that parameters
are only used within the module and are released immediately after
the forward and backward computations are completed.
This approach is crucial for efficient memory management, as releasing
parameters too early can lead to issues if other computations depend on them.
`optim` and `optim_grads` do not require FSDP units because they do not
shard model parameters.
"""
fsdp_unit_modules = self.fsdp_unit_modules
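# Release the all-gathered parameter buckets owned by `module` once they are no
# longer needed; optionally also drop the FP8 transpose cache of its parameters.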
def release_module_parameters(module, *unused):
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.release_bucket(bucket_id)
if not self.ddp_config.keep_fp8_transpose_cache_when_using_custom_fsdp:
release_params_fp8_transpose_cache(module.parameters())
def release_params_fp8_transpose_cache(params):
for param in params:
if is_float8tensor(param):
param._transpose_invalid = True
param._transpose = None
def all_gather_module_parameters(
module,
*unused,
prefetch=True,
prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER,
wait_bucket_ready=True,
):
ag_pipeline = self.all_gather_pipeline
ag_pipeline.all_gather_params(
params=list(module.parameters()),
prefetch=prefetch,
prefetch_order=prefetch_order,
suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
)
if wait_bucket_ready:
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
ag_pipeline.wait_bucket_ready(bucket_id)
def _grad_acc(param):
"""
Accumulate the gradient in the main_grad buffer.
"""
group_id = self.param_and_grad_buffer.param_to_param_group[param]
group = self.param_and_grad_buffer.parameter_groups[group_id]
if not group.requires_grad:
return
overwrite_main_grad = self.ddp_config.data_parallel_sharding_strategy in [
"optim_grads",
"optim_grads_params",
]
if overwrite_main_grad:
if not param.grad_added_to_main_grad:
# Accessing `main_grad` will allocate a bucket; check that the main_grad
# buffers currently in use do not exceed the scope of two FSDP unit
# modules, i.e., the buffer limit imposed by the double-buffer allocator.
if self.ddp_config.fsdp_double_buffer:
self.grad_reduce_pipeline._enforce_double_buffer_limit([group_id])
if param.grad is not None:
param.main_grad.copy_(param.grad)
del param.grad
else:
param.main_grad.zero_()
else:
if not param.grad_added_to_main_grad:
if param.grad is not None:
param.main_grad.add_(param.grad)
del param.grad
# Reset the grad accumulate flag.
param.grad_added_to_main_grad = False
self._params_require_handle_grad = set()
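# Post-backward hook: accumulate each parameter's gradient into its main_grad
# buffer, release the module's parameters under "optim_grads_params", and hand
# the gradients to the grad-reduce pipeline when a reduction is due.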
def _post_backward(module, *unused):
if isinstance(module, tuple(fsdp_unit_modules)):
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
release_module_parameters(module)
module._training_state = TrainingState.IDLE
param_list = list(module.parameters())
else:
param_list = list(module.parameters(recurse=False))
for param in param_list:
_grad_acc(param)
self._params_require_handle_grad.discard(param)
grad_reduce_every_bprop = self.ddp_config.data_parallel_sharding_strategy in [
"optim_grads",
"optim_grads_params",
]
if grad_reduce_every_bprop or self.is_last_microbatch:
self.grad_reduce_pipeline.reduce_gradients(
param_list,
suggested_queue_capacity=self.suggested_RS_queue_capacity,
inter_fsdp_group_grad_reduce=(
self.inter_fsdp_group_grad_reduce and self.is_last_microbatch
),
)
def _pre_forward_param_unshard(
module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
):
# Unshard the parameters before the forward pass.
input_training_state = module._training_state
fsdp_forward_prefetch = True
if input_training_state == TrainingState.PRE_BACKWARD:
# In the activation recomputation case, we need to cancel forward prefetch.
fsdp_forward_prefetch = False
else:
module._training_state = TrainingState.FORWARD
if isinstance(module, tuple(fsdp_unit_modules)):
param_list = list(module.parameters())
self.all_gather_pipeline.all_gather_params(
params=param_list,
prefetch=fsdp_forward_prefetch,
suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
)
for param in param_list:
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
else:
# All-gather the parameters in every forward pass for FSDP.
param_list = list(module.parameters(recurse=False))
self.all_gather_pipeline.all_gather_params(
params=param_list,
prefetch=fsdp_forward_prefetch,
suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
)
for param in param_list:
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
return args, kwargs
def _register_post_backward_hook(
post_backward_hook: callable,
module: nn.Module,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
):
# Register the backward function to reduce gradients after the backward pass.
# And for optim_grads_params, we need to release the parameters after the backward pass.
if not torch.is_grad_enabled():
return args, kwargs
args_list, args_spec = tree_flatten(args)
kwargs_list, kwargs_spec = tree_flatten(kwargs)
args_kwargs_list = list(args_list) + list(kwargs_list)
inp_tensor_indices: List[int] = []
inp_tensors: List[torch.Tensor] = []
for i, obj in enumerate(args_kwargs_list):
if torch.is_tensor(obj) and obj.requires_grad:
inp_tensor_indices.append(i)
inp_tensors.append(obj)
if len(inp_tensors) == 0:
return args, kwargs
inp_tensors = RegisterFSDPBackwardFunction.apply(
functools.partial(post_backward_hook, module), *inp_tensors
)
for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
args_kwargs_list[inp_tensor_idx] = inp_tensor
args_list = args_kwargs_list[: len(args_list)]
kwargs_list = args_kwargs_list[len(args_list) :]
args = tree_unflatten(args_list, args_spec)
kwargs = tree_unflatten(kwargs_list, kwargs_spec)
return args, kwargs
fsdp_modules = []
for name, module in root_module.named_modules():
if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
continue
if isinstance(module, tuple(fsdp_unit_modules)):
fsdp_modules.append(module)
self.forward_pre_hooks[f'module {name} parameter unshard'] = (
module.register_forward_pre_hook(
_pre_forward_param_unshard, prepend=True, with_kwargs=True
)
)
self.forward_pre_hooks[f"module {name} register post-backward hook"] = (
module.register_forward_pre_hook(
functools.partial(_register_post_backward_hook, _post_backward),
with_kwargs=True,
)
)
def _root_post_backward(*unused):
# Make sure all the gradients are handled.
for param in self._params_require_handle_grad:
_grad_acc(param)
# Reduce the remaining gradients.
grad_reduce_every_bprop = self.ddp_config.data_parallel_sharding_strategy in [
"optim_grads",
"optim_grads_params",
]
if grad_reduce_every_bprop or self.is_last_microbatch:
self.grad_reduce_pipeline.reduce_gradients(
list(self._params_require_handle_grad),
suggested_queue_capacity=self.suggested_RS_queue_capacity,
inter_fsdp_group_grad_reduce=(
self.inter_fsdp_group_grad_reduce and self.is_last_microbatch
),
)
self.grad_reduce_pipeline.reset()
# Reset root_pre_backward_hook_issued flag.
self._root_pre_backward_hook_issued = False
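# Pre-backward hook for FSDP unit modules: mark the module as PRE_BACKWARD and
# re-gather its parameters (prefetching in backward-pass order) before its
# backward computation runs.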
def _pre_backward(module: nn.Module, *unused):
module._training_state = TrainingState.PRE_BACKWARD
if isinstance(module, tuple(fsdp_unit_modules)):
all_gather_module_parameters(
module, prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER
)
self._root_pre_backward_hook_issued = False
def _root_pre_backward(module: nn.Module, *unused):
"""Marks the module's training state as 'pre_backward' before the
backprop, this function is registered on the root module.
This marking enables us to determine whether the forward pass needs to
perform reshard/unshard operations in activation recomputation
scenarios.
"""
if self._root_pre_backward_hook_issued:
return
self._root_pre_backward_hook_issued = True
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
for module in root_module.modules():
if isinstance(module, tuple(fsdp_unit_modules)):
module._training_state = TrainingState.PRE_BACKWARD
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.wait_bucket_ready(bucket_id, empty_ok=True)
self.all_gather_pipeline.release_bucket(bucket_id)
self._params_require_handle_grad = set()
for param_group in self.param_and_grad_buffer.parameter_groups:
if not param_group.requires_grad:
continue
self._params_require_handle_grad |= set(param_group.params)
for param in param_group.params:
param.grad_added_to_main_grad = False
torch.autograd.Variable._execution_engine.queue_callback(_root_post_backward)
def _post_forward(module: nn.Module, input: Any, output: Any):
# When composing with module-hook-based activation checkpointing, the
# post-backward hook is responsible for the reshard
if module._training_state == TrainingState.PRE_BACKWARD:
return output
release_module_parameters(module)
module._training_state = TrainingState.IDLE
return output
def _release_module_fp8_transpose_cache(module: nn.Module, *unused):
release_params_fp8_transpose_cache(module.parameters(recurse=False))
if len(fsdp_unit_modules) != 0:
fsdp_modules = []
for name, module in root_module.named_modules():
if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
continue
if isinstance(module, tuple(fsdp_unit_modules)):
fsdp_modules.append(module)
self.forward_hooks[f"release module {name} parameters"] = (
module.register_forward_hook(_post_forward, prepend=False)
)
self.backward_pre_hooks[f"all-gather module {name} parameters"] = (
module.register_full_backward_pre_hook(_pre_backward)
)
elif not self.ddp_config.keep_fp8_transpose_cache_when_using_custom_fsdp:
self.forward_hooks[f"remove module {name} fp8 transpose cache"] = (
module.register_forward_hook(
_release_module_fp8_transpose_cache, prepend=False
)
)
# Register the root pre-backward hook on every module that holds all of the model's
# parameters. This handles special cases where the forward function of root_module
# is not called, but the forward of one of these equivalent modules is called instead.
for name, module in root_module.named_modules():
if len(list(module.parameters())) != len(list(root_module.parameters())):
continue
self.backward_pre_hooks[f"{name} _root_pre_backward"] = (
module.register_full_backward_pre_hook(_root_pre_backward)
)
self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook(
_root_pre_backward
)
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
Note that in the gradient-sharding modes, gradient synchronization still happens
on every backward pass.
"""
# FIXME: Better handling of grads shard mode and no_sync in the training loop so that
# the code doesn't bog down developers.
self.is_last_microbatch = False
try:
yield
finally:
self.is_last_microbatch = True
def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
"""
Initiates param sync (all-gather) communication operations for all model parameters.
By default, when overlap_param_gather is set to True, dispatches asynchronous communication
calls; when overlap_param_gather is set to False, calls synchronous communication
ops. Can override this default behavior using flags below.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings.
force_dispatch (bool, optional): force dispatch regardless of other settings.
"""
if not force_sync and self.ddp_config.overlap_param_gather:
# All-gather the first bucket before the forward pass.
first_param = list(self.module.parameters())[0]
self.all_gather_pipeline.all_gather_params(params=[first_param], prefetch=False)
else:
self.all_gather_pipeline.reset()
for bucket_id in range(self.all_gather_pipeline.num_buckets):
self.all_gather_pipeline.async_bucket_gather(bucket_id)
group = self.param_and_grad_buffer.parameter_groups[bucket_id]
if group.model_weight_buffer is None:
continue
if group.model_weight_buffer.is_data_distributed:
# If model weight is sharded, we wait for the all-gather to complete and
# then release the bucket immediately to save memory usage.
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
for bucket_id in range(self.all_gather_pipeline.num_buckets):
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
if not self.ddp_config.overlap_grad_reduce:
if self.data_parallel_sharding_strategy == "no_shard":
self.param_and_grad_buffer.all_reduce_gradients(
async_op=self.ddp_config.overlap_grad_reduce
)
else:
self.param_and_grad_buffer.reduce_scatter_gradients()
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
if self.ddp_config.overlap_grad_reduce:
self.grad_reduce_pipeline.wait_for_previous_grad_reduce(0)
self.grad_reduce_pipeline.reset()
else:
self.start_grad_sync()
self.param_and_grad_buffer.update_main_grads()
if self.ddp_config.overlap_param_gather:
self.all_gather_pipeline.reset()
def optimizer_named_parameters(self) -> List[Tuple[str, torch.Tensor]]:
"""
Returns a list of tuples containing the main weights and their corresponding names
for mixed-precision training, to be used by the optimizer for updates.
Returns:
List[Tuple[str, torch.Tensor]]: A list of tuples, where each tuple
contains a main weight tensor and its corresponding name.
"""
return self.param_and_grad_buffer.optimizer_named_parameters
def scale_gradients(self, scaling_factor: float):
"""Scale all gradients inside the buffers by `scaling_factor`."""
self.param_and_grad_buffer.scale_gradients(scaling_factor)
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
for param in self.module.parameters():
if param.requires_grad:
param.grad_added_to_main_grad = False
self.param_and_grad_buffer.zero_grad()
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
for param in self.module.parameters():
is_expert_parallel = not getattr(param, 'allreduce', True)
if is_expert_parallel:
data_parallel_group = self.expt_dp_group
else:
data_parallel_group = self.dp_cp_group
torch.distributed.broadcast(
param.data,
src=torch.distributed.get_global_rank(data_parallel_group, 0),
group=data_parallel_group,
)
def load_state_dict(self, state_dict, strict=True):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
# make a copy of the state_dict to avoid modifying the input state_dict
state_dict = state_dict.copy()
state_dict_extra_states = {}
for key in list(state_dict.keys()):
if key.endswith("_extra_state"):
state_dict_extra_states[key] = state_dict[key]
del state_dict[key]
self.module.load_state_dict(state_dict_extra_states, strict=False)
prefix = "module."
buffer = self.param_and_grad_buffer
for param_groups in buffer.parameter_groups:
wbuf = param_groups.model_weight_buffer
for model_param in wbuf.params:
if is_float8tensor(model_param):
fp8_meta = model_param._fp8_meta['scaling_fwd']
fp8_meta_index = model_param._fp8_meta_index
model_param._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index])
param_name = f"{buffer.param_to_name[model_param]}"[len(prefix) :]
if param_name in state_dict:
if wbuf and wbuf.is_data_distributed:
model_param.fully_shard_param_local_shard.data.copy_(
state_dict[param_name]
)
else:
model_param.data.copy_(state_dict[param_name])
del state_dict[param_name]
self.module.load_state_dict(state_dict, strict=False)
return
self.module.load_state_dict(state_dict, strict=strict)
class RegisterFSDPBackwardFunction(torch.autograd.Function):
"""
Register a backward function that will be called after the backward pass
of the model. This function is used to release the parameters after the
backward pass.
"""
@staticmethod
def forward(ctx, post_backward, *inputs: torch.Tensor):
"""
Forward pass of the RegisterFSDPBackwardFunction function.
"""
ctx.post_backward = post_backward
return inputs
@staticmethod
def backward(ctx, *grads: torch.Tensor):
"""
Backward pass of the RegisterFSDPBackwardFunction function.
"""
ctx.post_backward()
return (None,) + grads
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import functools
import gc
import inspect
import logging
import math
import traceback
import warnings
from collections import defaultdict, namedtuple
from contextlib import ExitStack, nullcontext
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple
import torch
from torch.distributed import _coalescing_manager
from megatron.core import parallel_state
from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
from megatron.core.fp8_utils import is_float8tensor, modify_underlying_storage, quantize_param_shard
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.utils import is_submodule, is_te_min_version, log_on_each_pipeline_stage
try:
from transformer_engine.pytorch import fp8_model_init
except ImportError:
pass
try:
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
except ImportError:
pass
try:
import apex.contrib.nccl_allocator as nccl_allocator
except ImportError:
nccl_allocator = None
NCCL_MEMORY_POOL = None
logger = logging.getLogger(__name__)
def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
"""Alternate to ``assert`` when in the backward context to print the error
message ``s`` since otherwise, it is swallowed.
"""
if not cond:
print(s)
traceback.print_stack()
if raise_assertion_error:
raise AssertionError(s)
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
"""
Allocate storage for ``tensor`` with the given size.
Returns:
bool: ``True`` if this method allocated storage and ``False`` if the
storage was already allocated.
"""
with torch.no_grad():
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
already_allocated = tensor._typed_storage()._size() == size.numel()
if not already_allocated:
tensor_storage_size = tensor._typed_storage()._size()
_p_assert(
tensor_storage_size == 0,
"Tensor storage should have been resized to be 0 but got PLACEHOLDEr",
)
tensor._typed_storage()._resize_(size.numel())
def _free_storage(tensor: torch.Tensor):
"""
Frees the underlying storage of ``tensor``.
Returns:
bool: ``True`` if the method freed the storage and ``False`` if the
storage was already freed.
"""
with torch.no_grad():
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
already_freed = tensor._typed_storage()._size() == 0
if not already_freed:
_p_assert(
tensor.storage_offset() == 0,
"Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
f"storage offset: {tensor.storage_offset()}\n"
f"storage size: {tensor._typed_storage()._size()}\n"
f"tensor shape: {tensor.shape}",
)
tensor._typed_storage()._resize_(0)
TensorItemIndex = namedtuple(
'TensorItemIndex', ['global_data_index', 'size', 'item_id', 'bucket_id', 'shape']
)
BucketIndex = namedtuple('BucketIndex', ['bucket_id', 'global_data_index', 'size', 'items'])
ShardBucketIndex = namedtuple(
'ShardBucketIndex',
['bucket_id', 'global_data_index', 'local_data_index', 'bucket_data_index', 'size'],
)
class DualUBRAllocator:
"""
A custom allocator class that registers a single memory pool with two different
communication groups, which is not natively supported by apex's nccl_allocator.
This is particularly useful for Mixture of Experts (MoE) models where:
- Non-expert parameters/gradients use the data-parallel + context-parallel group (dp_cp_group)
- Expert parameters/gradients use the expert-parallel + data-parallel group (ep_dp_group)
Since Megatron-Core FSDP uses a single contiguous tensor for the entire model's parameters, we
need to register the same memory pool with both communication groups to enable NCCL algorithms
that rely on user buffer registration for both expert and non-expert parameters.
Implementation:
It uses apex's nccl_allocator internally to create a tensor with ncclMemAlloc,
registers it to `group`, and then also registers the memory pool for `additional_group`.
Example:
```
import apex.contrib.nccl_allocator as nccl_allocator
nccl_allocator.init()
pool = nccl_allocator.create_nccl_mem_pool()
group_1 = torch.distributed.new_group(ranks=[0, 1, 2, 3, 4, 5, 6, 7], backend="nccl")
group_2 = torch.distributed.new_group(ranks=[0, 2, 4, 6], backend="nccl")
with DualUBRAllocator(pool, group_1, group_2):
a = torch.zeros(1024, dtype=torch.float32, device="cuda")
b = torch.zeros(1024, dtype=torch.float32, device="cuda")
```
"""
def __init__(
self,
pool, # torch.cuda.MemPool
group, # torch.distributed.ProcessGroup
additional_group, # torch.distributed.ProcessGroup
):
self.pool = pool
self.group = group
self.additional_group = additional_group
self.mem_allocator = nccl_allocator.nccl_mem(self.pool, group=self.group)
def __enter__(self):
backend = self.additional_group._get_backend(
torch.device("cuda", torch.cuda.current_device())
)
try:
# Since registration is done at mempool granularity, we need to deregister
# the mempool here and re-register it, including the newly created
# tensors, after the context is exited.
backend.deregister_mem_pool(self.pool)
except RuntimeError:
pass
self.mem_allocator.__enter__()
def __exit__(self, *args):
self.mem_allocator.__exit__(*args)
backend = self.additional_group._get_backend(
torch.device("cuda", torch.cuda.current_device())
)
backend.register_mem_pool(self.pool)
@dataclasses.dataclass
class BucketingPolicy:
"""
A policy for bucketing in Fully Sharded Data Parallel (FSDP) training.
Attributes:
suggested_bucket_size (int): The suggested size of each bucket in number of elements.
fsdp_unit_modules (list): A list of module classes that are treated as a
single unit for FSDP bucketing.
data_parallel_sharding_strategy (str): The strategy used for sharding
data parallel modules.
Note:
This policy is used to configure the bucketing behavior in FSDP training.
"""
suggested_bucket_size: Optional[int] = 40_000_000
fsdp_unit_modules: List[torch.nn.Module] = dataclasses.field(default_factory=list)
data_parallel_sharding_strategy: str = 'no_shard'
def _pad(number_to_be_padded: int, divisor: int) -> int:
return int(math.ceil(number_to_be_padded / divisor) * divisor)
def build_data_parallel_buffer_index(
elements: List[torch.Size],
data_parallel_rank: int,
data_parallel_world_size: int,
is_data_distributed: bool,
ddp_config: DistributedDataParallelConfig,
bucket_id: int = 0,
) -> Tuple[List[TensorItemIndex], BucketIndex, ShardBucketIndex]:
"""
Assuming that all input tensors are laid out consecutively in a single global
buffer, compute the index range of every tensor, every bucket, and every
in-bucket local shard.
Args:
elements (List[torch.Size]): Shapes of the input tensors.
data_parallel_rank (int): Rank of the current process in the data parallel group.
data_parallel_world_size (int): World size of the data parallel group.
is_data_distributed (bool): Whether the buffer data is sharded across the
data parallel group.
ddp_config (DistributedDataParallelConfig): DistributedDataParallel config object.
bucket_id (int, optional): The id of the bucket. Defaults to 0.
Returns:
Tuple[List[TensorItemIndex], BucketIndex, ShardBucketIndex]: The index
range of every tensor, every bucket and every in bucket local buffer.
"""
def _pad_if_needed(data_index: int) -> int:
"""
Pads data indices if using distributed optimizer (to ensure uniform sharding).
"""
if ddp_config.data_parallel_sharding_strategy != 'no_shard':
# Workaround for TE bug causing cuBLAS to pick an incompatible algorithm.
# This also helps cuBLAS pick more efficient algorithms for GEMMs.
# We now ensure that all buckets start at a memory address that is 256-byte
# aligned (128 values since params and grads use >= 16-bit precision).
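# For example, with data_parallel_world_size=8 the padding is to a multiple of
# lcm(8, 128) = 128 elements; with world size 48 it is lcm(48, 128) = 384.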
return _pad(data_index, math.lcm(data_parallel_world_size, 128))
return data_index
def add_item(item_id, item, bucket, item_index_map, bucket_id):
bucket.append(item)
bucket_size = sum([it.numel() for it in bucket])
item_index_map.append(
TensorItemIndex(
data_index + bucket_size - item.numel(),
item.numel(),
item_id=item_id,
bucket_id=bucket_id,
shape=item,
)
)
item_index_map = []
bucket = []
data_index = 0
for item_id, item in enumerate(elements):
add_item(item_id, item, bucket, item_index_map, bucket_id)
bucket_size = sum([it.numel() for it in bucket])
bucket_size = _pad_if_needed(bucket_size)
bucket_index = BucketIndex(
bucket_id,
data_index,
bucket_size,
items=list(filter(lambda x: x.bucket_id == bucket_id, item_index_map)),
)
shard_size = bucket_index.size // data_parallel_world_size
bucket_data_index = shard_size * data_parallel_rank
global_data_index = bucket_index.global_data_index + bucket_data_index
if is_data_distributed:
shard_bucket_index = ShardBucketIndex(
bucket_id, global_data_index, 0, bucket_data_index, shard_size
)
else:
shard_bucket_index = ShardBucketIndex(
bucket_id, global_data_index, global_data_index, bucket_data_index, shard_size
)
return item_index_map, bucket_index, shard_bucket_index
@dataclasses.dataclass
class Bucket:
"""
A container for holding data in Fully Sharded Data Parallel (FSDP) training.
Attributes:
data (torch.Tensor): A tensor containing the data elements
grouped together in a bucket.
data_operation_event (Optional[torch.cuda.Event]): An optional CUDA event
used to synchronize data operations.
status (Any): An optional status object used to track the state of the bucket.
Note:
Buckets are used to optimize communication in FSDP training by
grouping small tensors together.
"""
data: torch.Tensor
data_operation_event: Optional[torch.cuda.Event] = None
status: Any = None
class TemporaryBucketAllocator:
"""
A utility class for managing temporary buckets (buffers) used in FSDP
operations such as parameter unsharding and gradient reduction.
This allocator handles the dynamic allocation and deallocation of temporary memory buffers
needed during FSDP (Fully Sharded Data Parallel) operations, particularly for parameter
unsharding and gradient reduction. It helps optimize memory usage by allowing temporary
buckets to be released when no longer needed.
Key Features:
- Dynamic allocation of temporary buckets for FSDP operations
- Memory-efficient management of temporary buffers
- Support for both parameter unsharding and gradient reduction operations
- Automatic cleanup of unused buckets to save memory
Usage:
```python
# Create an allocator instance
allocator = TemporaryBucketAllocator()
# Allocate a temporary bucket
temp_bucket = allocator.allocate(
bucket_id=0, size=1024, dtype=torch.float32, device=torch.device("cuda")
)
# Use the temporary bucket for FSDP operations
# ... perform all-gather or reduce-scatter ...
# Free the bucket when done
allocator.free(bucket_id=0)
```
Note:
It's important to release temporary buckets after use to prevent memory leaks
and optimize memory usage during training.
"""
def __init__(self):
self.buckets = {}
def allocate(
self,
bucket_id: int,
size: int,
dtype: torch.dtype,
device: torch.device,
mem_alloc_context: Optional[Callable] = None,
) -> Bucket:
"""
allocate a temporary bucket.
"""
if bucket_id not in self.buckets:
self.buckets[bucket_id] = Bucket(data=torch.empty(size, dtype=dtype, device=device))
return self.buckets[bucket_id]
def free(self, bucket_id: int):
"""
free a temporary bucket.
"""
if bucket_id in self.buckets:
_free_storage(self.buckets[bucket_id].data)
del self.buckets[bucket_id]
class StorageResizeBasedBucketAllocator(TemporaryBucketAllocator):
"""
A specialized temporary bucket allocator that resizes the storage of temporary buckets
based on the required size.
"""
def __init__(self):
self.buckets = {} # {bucket_id: Bucket}
def allocate(
self,
bucket_id: int,
size: int,
dtype: torch.dtype,
device: torch.device,
mem_alloc_context: Optional[Callable] = None,
) -> Bucket:
"""
allocate a temporary bucket.
"""
if bucket_id not in self.buckets:
self.buckets[bucket_id] = Bucket(data=torch.empty(size, dtype=dtype, device=device))
bucket = self.buckets[bucket_id]
_alloc_storage(bucket.data, torch.Size([size]))
return bucket
def free(self, bucket_id: int):
"""
free a temporary bucket.
"""
if bucket_id in self.buckets:
_free_storage(self.buckets[bucket_id].data)
class RotaryBucketAllocator(TemporaryBucketAllocator):
"""A specialized temporary bucket allocator that implements a circular buffer recycling strategy
to minimize memory fragmentation in FSDP operations.
RotaryBucketAllocator extends TemporaryBucketAllocator by maintaining a limited pool of
pre-allocated buffers that are reused in a circular manner. This approach helps prevent
memory fragmentation that typically occurs with frequent allocation and deallocation of
temporary buffers during FSDP operations.
Key Features:
- Circular buffer recycling strategy for memory efficiency
- Reduced memory fragmentation compared to dynamic allocation
- Pre-allocated buffer pool for faster access
- Automatic buffer reuse without explicit deallocation
Usage:
```python
# Create a rotary allocator
allocator = RotaryBucketAllocator(name="gpt_parameters")
# Get a temporary buffer from the pool
temp_bucket = allocator.allocate(
bucket_id=0, size=1024, dtype=torch.float32, device=torch.device("cuda")
)
# Use the temporary bucket for FSDP operations
# ... perform all-gather or reduce-scatter ...
# Free the bucket when done, returning it to the idle buffer pool
allocator.free(bucket_id=0)
```
"""
def __init__(self, name: str):
self.name = name
self.num_global_buffer = 0
self.idle_buffer = [] # [buffer_id]
self.using_buffer = {} # {bucket_id: buffer_id}
self.buckets = {}
def allocate(
self,
bucket_id: int,
size: int,
dtype: torch.dtype,
device: torch.device,
mem_alloc_context: Optional[Callable] = None,
) -> Bucket:
"""
allocate a temporary bucket.
"""
def _get_global_buffer(buffer_id: int):
return parallel_state.get_global_memory_buffer().get_tensor(
[size],
dtype=dtype,
name=self._get_gbuf_name(buffer_id),
mem_alloc_context=mem_alloc_context,
)
if bucket_id in self.using_buffer:
buffer_id = self.using_buffer[bucket_id]
return Bucket(data=_get_global_buffer(buffer_id))
if len(self.idle_buffer) == 0:
# allocate new buffer
buffer_id = self.num_global_buffer
self.num_global_buffer += 1
self.idle_buffer.append(buffer_id)
buffer_id = self.idle_buffer.pop(0)
self.using_buffer[bucket_id] = buffer_id
return Bucket(data=_get_global_buffer(buffer_id))
def _get_gbuf_name(self, buffer_id: int):
return f"{self.name}_{buffer_id}"
def free(self, bucket_id: int):
"""
free a temporary bucket.
"""
if bucket_id in self.using_buffer:
buffer_id = self.using_buffer.pop(bucket_id)
self.idle_buffer.append(buffer_id)
class FixedPoolAllocator(TemporaryBucketAllocator):
"""
A specialized temporary bucket allocator that implements a buffer recycling strategy
to minimize memory fragmentation in FSDP operations.
This allocator maintains a fixed pool of pre-allocated buffers, reusing them
to reduce the overhead and fragmentation caused by frequent allocation and
deallocation of temporary buffers during FSDP operations.
"""
def __init__(self, name: str, fsdp_param_groups: List["ParameterGroup"], size: int = 2):
self.name = name
self.fsdp_param_groups = fsdp_param_groups
self.size = size # Number of buffers in the pool (default is 2 for double buffering)
self.allocation_tracker = {} # tracking the global buffer allocation status
# Build a mapping from FSDP unit id to its associated bucket ids.
fsdp_unit_buckets = defaultdict(list)
for bucket_id, param_group in enumerate(fsdp_param_groups):
if param_group.fsdp_unit_id == -1 or param_group.fsdp_unit_id is None:
continue
fsdp_unit_buckets[param_group.fsdp_unit_id].append(bucket_id)
self.fsdp_unit_buckets = fsdp_unit_buckets
# Identify the largest group of FSDP units that share the same buffer storage.
fsdp_units_to_double_buffer = []
for fsdp_unit_id, bucket_ids in fsdp_unit_buckets.items():
same_storage_fsdp_units = []
for i in fsdp_unit_buckets:
if self._is_two_bucket_group_equal(fsdp_unit_buckets[i], bucket_ids):
same_storage_fsdp_units.append(i)
# Track the largest group of FSDP units sharing the same buffer storage
if len(same_storage_fsdp_units) > len(fsdp_units_to_double_buffer):
fsdp_units_to_double_buffer = same_storage_fsdp_units
# --- Fixed Pool Buffering Check ---
# Ensure there is at least one group of FSDP units eligible for fixed pool buffering.
# If not, the allocator cannot provide its intended memory recycling benefits.
assert (
len(fsdp_units_to_double_buffer) > 0
), "Found no FSDP units to use fixed-size buffering"
self.fsdp_double_buffer_units = fsdp_units_to_double_buffer
# Initialize buffer group status.
# Each buffer group represents a set of buffers associated with an FSDP unit's bucket group.
self.idle_buffer = [] # List of available (buf_group_id, offset) tuples.
self.using_buffer = {} # Map from bucket_id to (buf_group_id, offset) in use.
# Populate the idle buffer pool with all buffer group and bucket offset combinations.
for buf_group_id in range(self.size): # Iterate over each buffer group in the pool.
num_bucket = len(self.fsdp_unit_buckets[self.fsdp_double_buffer_units[0]])
for bucket_offset in range(num_bucket):
self.idle_buffer.append((buf_group_id, bucket_offset))
# Fallback allocator used if the fixed pool allocator cannot fulfill a request.
self.backup_allocator = TemporaryBucketAllocator()
def _is_two_bucket_group_equal(self, group_a, group_b):
# Check if two bucket groups are equivalent in dtype and size.
if len(group_a) != len(group_b):
return False
for a, b in zip(group_a, group_b):
pg_a = self.fsdp_param_groups[a]
pg_b = self.fsdp_param_groups[b]
a_size = sum(p.numel() for p in pg_a.params)
b_size = sum(p.numel() for p in pg_b.params)
if pg_a.dtype != pg_b.dtype or a_size != b_size:
return False
return True
def allocate(
self,
bucket_id: int,
size: int,
dtype: torch.dtype,
device: torch.device,
mem_alloc_context: Optional[Callable] = None,
) -> Bucket:
"""
allocate a temporary bucket.
"""
fsdp_unit_id = self.fsdp_param_groups[bucket_id].fsdp_unit_id
if fsdp_unit_id in self.fsdp_double_buffer_units:
# Try to allocate from the buffer pool.
bucket_offset = self.fsdp_unit_buckets[fsdp_unit_id].index(bucket_id)
buffer_name = None
if bucket_id in self.using_buffer:
# If this bucket is already using a buffer, reuse it.
buf_group_id, bucket_offset = self.using_buffer[bucket_id]
buffer_name = self._get_gbuf_name(buf_group_id, bucket_offset)
else:
# Otherwise, find an available buffer group for this bucket offset.
for buf_group_id in range(self.size):
if (buf_group_id, bucket_offset) in self.idle_buffer:
self.using_buffer[bucket_id] = (buf_group_id, bucket_offset)
buffer_name = self._get_gbuf_name(buf_group_id, bucket_offset)
self.idle_buffer.remove((buf_group_id, bucket_offset))
break
assert buffer_name is not None, (
f"[FSDP][Rank {torch.distributed.get_rank()}][{self.name}]"
f"No buffer found for bucket_id: {bucket_id}, fsdp_unit_id: {fsdp_unit_id}, "
f"bucket_offset: {bucket_offset} \n"
f"current using_buffer: {self.using_buffer} \n"
f"current idle_buffer: {self.idle_buffer}"
)
# Synchronization is required before a new user-buffer allocation.
if mem_alloc_context is not None and mem_alloc_context != nullcontext:
# Check if a new buffer allocation is required
if (
self.allocation_tracker.get((buffer_name, dtype), None) is None
or self.allocation_tracker[(buffer_name, dtype)] < size
):
# Requires synchronization for new buffer allocation
self.allocation_tracker[(buffer_name, dtype)] = size
torch.cuda.synchronize()
return Bucket(
data=parallel_state.get_global_memory_buffer().get_tensor(
[size], dtype=dtype, name=buffer_name, mem_alloc_context=mem_alloc_context
)
)
# If the bucket is not eligible for fixed pool buffering, or no buffer is available,
# fall back to dynamic allocation via the backup allocator. This means that we
# will do dynamic memory allocation.
logging.debug(f"[FSDP] Using backup allocator for {bucket_id} {fsdp_unit_id}")
return self.backup_allocator.allocate(
bucket_id=bucket_id, size=size, dtype=dtype, device=device
)
def _get_gbuf_name(self, buf_group_id: int, bucket_index: int):
return f"{self.name}_{buf_group_id}_{bucket_index}"
def free(self, bucket_id: int):
"""
free a temporary bucket.
"""
fsdp_unit_id = self.fsdp_param_groups[bucket_id].fsdp_unit_id
if fsdp_unit_id in self.fsdp_double_buffer_units:
if bucket_id not in self.using_buffer:
# This bucket is not allocated by fixed pool allocator.
return
# Return the buffer to the idle pool.
self.idle_buffer.append(self.using_buffer[bucket_id])
del self.using_buffer[bucket_id]
return
# If not managed by fixed pool allocator, delegate to the backup allocator.
logging.debug(f"[FSDP] Free from the backup allocator for {bucket_id} {fsdp_unit_id}")
self.backup_allocator.free(bucket_id)
class DataParallelBuffer:
"""
A class that manages the data parallel buffer for Fully Sharded Data Parallel (FSDP) training.
"""
def __init__(
self,
ddp_config: DistributedDataParallelConfig,
params: List[torch.nn.Parameter],
is_data_distributed: bool,
bucket_id: int,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
inter_data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
temporary_bucket_allocator: Optional[TemporaryBucketAllocator] = None,
init_meta_only: bool = False,
is_dtype_float8: bool = False,
gradient_scaling_factor: Optional[float] = None,
mem_alloc_context: Optional[Callable] = None,
) -> None:
self.ddp_config = ddp_config
self.params = params
_param_dtype = {p.dtype for p in self.params}
assert len(_param_dtype) == 1, f'params have different dtypes: {_param_dtype}'
self.is_data_distributed = is_data_distributed
self.bucket_id = bucket_id
self.dtype = dtype if dtype else next(iter(_param_dtype))
self.device = device
self.data_parallel_group = data_parallel_group
self.inter_data_parallel_group = inter_data_parallel_group
self.dp_rank = self.data_parallel_group.rank()
self.dp_world_size = self.data_parallel_group.size()
self.temporary_bucket_allocator = (
temporary_bucket_allocator if temporary_bucket_allocator else TemporaryBucketAllocator()
)
self.is_dtype_float8 = is_dtype_float8
self.gradient_scaling_factor = gradient_scaling_factor
self.mem_alloc_context = mem_alloc_context if mem_alloc_context else nullcontext
(self.item_index_map, self.bucket_index, self.shard_bucket_index) = (
build_data_parallel_buffer_index(
[p.shape for p in self.params],
self.dp_rank,
self.dp_world_size,
is_data_distributed,
ddp_config,
bucket_id=bucket_id,
)
)
self.data_size = (
self.bucket_index.size if not is_data_distributed else self.shard_bucket_index.size
)
if init_meta_only:
self.data = None
else:
self.data = torch.empty(self.data_size, dtype=self.dtype, device=device)
self.param_idx = {p: i for i, p in enumerate(self.params)}
self.placeholder_bucket = None
self.placeholder_items = {}
def fetch_bucket(
self, dtype: Optional[torch.dtype] = None, and_allocate_params_data: bool = False
) -> Bucket:
"""
Fetch a communication buffer for data-parallel operations.
The size of the bucket is defined by the `DataParallelBuffer` instance.
If `and_allocate_params_data` is True, this method resets the parameter
data stored in the `DataParallelBuffer` instance.
Args:
dtype (Optional[torch.dtype], optional): The data type of the tensor
to fetch a buffer for. Defaults to None.
and_allocate_params_data (bool, optional): Whether to allocate and
reset parameter data. Defaults to False.
Returns:
Bucket: The communication buffer for the specified data type.
"""
if dtype is None:
dtype = self.dtype
bucket_index = self.bucket_index
if not self.is_data_distributed and dtype == self.dtype:
bucket = Bucket(
data=self.data[
bucket_index.global_data_index : bucket_index.global_data_index
+ bucket_index.size
]
)
else:
bucket = self.temporary_bucket_allocator.allocate(
bucket_id=bucket_index.bucket_id,
size=bucket_index.size,
dtype=dtype,
device=self.device,
mem_alloc_context=self.mem_alloc_context,
)
if and_allocate_params_data:
for p in self.params:
item_id = self.param_idx[p]
if is_float8tensor(p):
p._data = self.get_item_from_bucket(bucket, item_id).view(p.shape)
else:
p.data = self.get_item_from_bucket(bucket, item_id).view(p.shape)
return bucket
def free_bucket_storage(self, and_free_params_data: bool = False):
"""
Release the storage of a temporary communication bucket.
If the bucket is temporary, this method frees its storage.
If `and_free_params_data` is True, this method also releases the storage
of the parameter data stored in the `DataParallelBuffer` instance.
Args:
and_free_params_data (bool, optional): Whether to also release the
storage of the parameter data. Defaults to False.
Returns:
None
"""
if not self.is_data_distributed:
return
self.temporary_bucket_allocator.free(self.bucket_index.bucket_id)
if and_free_params_data:
if self.placeholder_bucket is None:
self.placeholder_bucket = Bucket(
data=torch.empty(self.bucket_index.size, dtype=self.dtype, device=self.device)
)
for p in self.params:
item_id = self.param_idx[p]
self.placeholder_items[item_id] = self.get_item_from_bucket(
self.placeholder_bucket, item_id
).view(p.shape)
_free_storage(self.placeholder_bucket.data)
for p in self.params:
item_id = self.param_idx[p]
if is_float8tensor(p):
p._data = self.placeholder_items[item_id]
else:
p.data = self.placeholder_items[item_id]
def _get_item_slice_in_shard(self, item_id: int) -> Tuple[int, int]:
item_index = self.item_index_map[item_id]
shard_bucket_index = self.shard_bucket_index
item_global_start = item_index.global_data_index
item_global_end = item_index.global_data_index + item_index.size
shard_bucket_start = shard_bucket_index.global_data_index
shard_bucket_end = shard_bucket_index.global_data_index + shard_bucket_index.size
if item_global_start > shard_bucket_end or item_global_end < shard_bucket_start:
return (0, 0)
start = max(item_global_start, shard_bucket_start) - item_global_start
end = min(item_global_end, shard_bucket_end) - item_global_start
return (start, end)
# pylint: disable=missing-function-docstring
def locate_item_in_global_item(self, item_id: int) -> Tuple[int, int]:
item_index = self.item_index_map[item_id]
if not self.is_data_distributed:
return (0, item_index.size)
slice_start, slice_end = self._get_item_local_shard_index(item_id)
if slice_start == slice_end:
return (0, 0)
local_shard_index_to_global_index_offset = (
self.shard_bucket_index.global_data_index - self.shard_bucket_index.local_data_index
)
slice_start += local_shard_index_to_global_index_offset
slice_end += local_shard_index_to_global_index_offset
return (
slice_start - item_index.global_data_index,
slice_end - item_index.global_data_index,
)
def _get_item_local_shard_index(self, item_id: int) -> Tuple[int, int]:
slice_start, slice_end = self._get_item_slice_in_shard(item_id)
if slice_start == slice_end:
return (0, 0)
item_index = self.item_index_map[item_id]
shard_bucket_index = self.shard_bucket_index
offset = (
item_index.global_data_index
- shard_bucket_index.global_data_index
+ shard_bucket_index.local_data_index
)
return (offset + slice_start, offset + slice_end)
def _get_item_local_index(self, item_id: int) -> Tuple[int, int]:
if not self.is_data_distributed:
item_index = self.item_index_map[item_id]
return (item_index.global_data_index, item_index.global_data_index + item_index.size)
return self._get_item_local_shard_index(item_id)
def set_item(self, item_id: int, item: torch.Tensor) -> None:
"""
Update a tensor item managed by the `DataParallelBuffer` instance.
The storage of the item is mapped to the communication bucket.
This method updates the item data and ensures consistency with the bucket.
Args:
item_id (int): The ID of the tensor item to update.
item (torch.Tensor): The original tensor to be put into the buffer.
Returns:
None
"""
if is_float8tensor(item):
item_data = item._data
else:
item_data = item.data
if self.is_data_distributed:
slice_start, slice_end = self._get_item_slice_in_shard(item_id)
item_data = item_data.flatten()[slice_start:slice_end]
local_index_start, local_index_end = self._get_item_local_index(item_id)
shard = self.data[local_index_start:local_index_end]
if shard.numel() > 0:
shard.data.copy_(item_data.flatten())
def get_item(self, item_id: int, only_shard: bool = False) -> torch.Tensor:
"""
Retrieve a tensor item managed by the `DataParallelBuffer` instance.
The storage of the item is mapped to the communication bucket.
If `only_shard` is True, returns only the shard of the item corresponding
to the current process.
Otherwise, returns the entire item.
Args:
item_id (int): The ID of the tensor item to retrieve.
only_shard (bool, optional): Whether to return only the shard of the
item. Defaults to False.
Returns:
torch.Tensor: The retrieved tensor item.
"""
if only_shard:
start, end = self._get_item_local_shard_index(item_id)
else:
start, end = self._get_item_local_index(item_id)
return self.data[start:end]
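# Illustrative usage sketch (assumes `buf` is a DataParallelBuffer managing a
# parameter `p` registered as item 0):
#   buf.set_item(0, p)                        # copy p's data (or its local shard) into the buffer
#   whole = buf.get_item(0)                   # locally stored slice of item 0
#   shard = buf.get_item(0, only_shard=True)  # only this rank's shard of item 0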
def get_item_from_bucket(self, bucket: Bucket, item_id: int):
"""get item from bucket."""
item_index = self.item_index_map[item_id]
bucket_index = self.bucket_index
start_index = item_index.global_data_index - bucket_index.global_data_index
end_index = start_index + item_index.size
item = bucket.data[start_index:end_index]
return item
def get_shard_from_bucket(self, bucket: Bucket):
"""Get the local sharding of the bucket."""
shard_bucket_index = self.shard_bucket_index
offset = shard_bucket_index.bucket_data_index
shard_size = shard_bucket_index.size
shard = bucket.data[offset : offset + shard_size]
return shard
def get_shard_from_local_buffer(self) -> torch.Tensor:
"""Get the local sharding of the bucket."""
index = self.shard_bucket_index
return self.data[index.local_data_index : index.local_data_index + index.size]
@dataclasses.dataclass
class ParameterGroup:
"""
A group of model parameters with associated metadata for data-parallel training.
This dataclass encapsulates a list of PyTorch parameters and additional information
necessary for managing data-parallel operations, such as data type, gradient requirements,
and buffer assignments.
"""
params: List[torch.nn.Parameter]
dtype: Optional[torch.dtype] = None
is_expert_param: bool = False
requires_grad: Optional[bool] = None
fsdp_unit_id: Optional[int] = None
data_parallel_world_size: Optional[int] = None
model_weight_buffer: Optional[DataParallelBuffer] = None
main_weight_buffer: Optional[DataParallelBuffer] = None
main_grad_buffer: Optional[DataParallelBuffer] = None
def _get_parameter_groups(
module: torch.nn.Module,
policy: BucketingPolicy,
meta_device_init_fp8_params: dict,
bucket_group_by_fsdp_unit: bool = True,
):
"""
Get the parameter groups for the given module based on the bucketing policy.
"""
param_to_name = {p: name for name, p in module.named_parameters()}
fsdp_units = []
if policy.fsdp_unit_modules:
param_to_id = {}
for i, p in enumerate(module.parameters()):
param_to_id[p] = i
fsdp_modules = []
for m in module.modules():
# Skip nested FSDP modules.
if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
continue
if isinstance(m, tuple(policy.fsdp_unit_modules)):
fsdp_units.append([param_to_name[p] for p in m.parameters()])
fsdp_modules.append(m)
def _does_param_require_new_bucket(param):
"""
Split shared embedding parameters into separate bucket if using distributed
optimizer that makes use of reduce-scatters instead of all-reduces.
This ensures that the first and last pipeline stage partition optimizer state
for the shared embedding parameters the same way across DP replicas, allowing
the DP reduce-scatter to be before the embedding all-reduce.
"""
return (
getattr(param, "shared_embedding", False)
and policy.data_parallel_sharding_strategy != "no_shard"
)
is_expert_parameter = lambda p: not getattr(p, 'allreduce', True)
# Step 1: Group the parameters according to their execution order and attributes.
parameter_groups = []
for name, param in module.named_parameters():
param_attrs = dict(
dtype=(
"float8"
if is_float8tensor(param) or meta_device_init_fp8_params.get(name, False)
else param.dtype
),
is_expert_param=is_expert_parameter(param),
requires_grad=param.requires_grad,
fsdp_unit_id=None,
)
for fsdp_unit_id, fsdp_unit in enumerate(fsdp_units):
if name in fsdp_unit:
param_attrs["fsdp_unit_id"] = fsdp_unit_id
break
found_group = False
for param_group in parameter_groups:
group_attrs = {
key: value for key, value in param_group.__dict__.items() if key in param_attrs
}
if group_attrs == param_attrs:
param_group.params.append(param)
found_group = True
break
if not found_group:
parameter_groups.append(ParameterGroup([param], **param_attrs))
# Step 2: Bucket the parameters based on the guide bucket size.
suggested_bucket_size = policy.suggested_bucket_size
bucket_groups = []
for group in parameter_groups:
bucket = []
basic_attrs = {
key: value
for key, value in group.__dict__.items()
if key in ['dtype', 'is_expert_param', 'requires_grad', 'fsdp_unit_id']
}
for param in group.params:
if _does_param_require_new_bucket(param):
if len(bucket) > 0:
bucket_groups.append(ParameterGroup(bucket, **basic_attrs))
bucket_groups.append(ParameterGroup([param], **basic_attrs))
bucket = []
continue
bucket.append(param)
if (
group.fsdp_unit_id is None
and suggested_bucket_size
and sum([p.numel() for p in bucket]) >= suggested_bucket_size
):
bucket_groups.append(ParameterGroup(bucket, **basic_attrs))
bucket = []
continue
if bucket:
bucket_groups.append(ParameterGroup(bucket, **basic_attrs))
param_to_param_group = {}
for group_id, group in enumerate(bucket_groups):
for param in group.params:
param_to_param_group[param] = group_id
# Generate the groups of collective buckets, where each group aggregates
# the collectives per FSDP unit. This improves performance by reducing
# the number of collective calls and increasing per-collective efficiency.
#
# Set default aggregate buckets of bucket.
bucket_to_bucket_group = {}
for bucket_id in range(len(bucket_groups)):
bucket_to_bucket_group[bucket_id] = [bucket_id]
# Set aggregate buckets by FSDP units.
if bucket_group_by_fsdp_unit:
bucket_group_map = {}
for bucket_id, param_group in enumerate(bucket_groups):
if param_group.fsdp_unit_id is None:
continue
id = (param_group.fsdp_unit_id, param_group.is_expert_param)
if id not in bucket_group_map:
bucket_group_map[id] = []
bucket_group_map[id].append(bucket_id)
for bucket_group in bucket_group_map.values():
for bucket_id in bucket_group:
bucket_to_bucket_group[bucket_id] = bucket_group
return (bucket_groups, param_to_param_group, bucket_to_bucket_group)
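# Illustrative sketch of the returned structures for a hypothetical module with
# two FSDP-unit layers holding parameters (w0, b0) and (w1, b1):
#   bucket_groups          -> [ParameterGroup([w0, b0], fsdp_unit_id=0, ...),
#                              ParameterGroup([w1, b1], fsdp_unit_id=1, ...)]
#   param_to_param_group   -> {w0: 0, b0: 0, w1: 1, b1: 1}
#   bucket_to_bucket_group -> {0: [0], 1: [1]}  # buckets aggregated per FSDP unit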
class ParamAndGradBuffer:
"""A class that manages parameter grouping, buffer allocation, and
communication operations for data-parallel distributed training.
This class provides functionality to:
1. Group parameters based on their data types and communication group sizes
2. Create contiguous buffers for model weights, gradients, and high-precision
main weights
3. Handle parameter unsharding, gradient reduction, and weight
synchronization operations
Key Features:
- Efficient parameter grouping based on data types and communication patterns
- Memory-efficient contiguous buffer allocation
- Support for mixed-precision training with main weights
- Distributed operations including parameters all-gather and gradients
reduce-scatter/all-reduce
- Synchronized weight updates between model and main weights
Note:
This class is designed for distributed training scenarios where efficient
parameter management and communication are crucial for performance.
Args:
ddp_config (DistributedDataParallelConfig): The distributed data parallel
configuration.
module (torch.nn.Module): The module whose parameters are to be grouped
and flattened.
bucketing_policy (BucketingPolicy): The bucketing policy.
data_parallel_group (torch.distributed.ProcessGroup): The data parallel group.
expert_data_parallel_group (Optional[torch.distributed.ProcessGroup]):
The expert data parallel group.
preserve_fp32_weights (bool): Whether to preserve FP32 weights.
grad_reduce_in_fp32 (bool): Whether to reduce gradients in FP32.
gradient_scaling_factor (Optional[float]): The gradient scaling factor.
expert_gradient_scaling_factor (Optional[float]): The expert gradient
scaling factor.
device (torch.device): The parameter and gradient buffer device.
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad (bool):
Whether to only create the gradient buffer and main weight buffer
for parameters that require gradients. Default is True.
"""
def __init__(
self,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
bucketing_policy: BucketingPolicy,
data_parallel_group: torch.distributed.ProcessGroup,
expert_data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
inter_data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
preserve_fp32_weights: bool = True,
grad_reduce_in_fp32: bool = True,
gradient_scaling_factor: Optional[float] = None,
expert_gradient_scaling_factor: Optional[float] = None,
device: torch.device = torch.device('cuda'),
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad: bool = True,
reset_parameters_for_meta_device_init_module: bool = False,
):
self.ddp_config = ddp_config
self.module = module
self.bucketing_policy = bucketing_policy
self.param_to_name = {p: name for name, p in self.module.named_parameters()}
self.preserve_fp32_weights = preserve_fp32_weights
self.grad_reduce_in_fp32 = grad_reduce_in_fp32
self.data_parallel_group = data_parallel_group
self.expert_data_parallel_group = expert_data_parallel_group
self.inter_data_parallel_group = inter_data_parallel_group
self.params = list(module.parameters())
self.gradient_scaling_factor = gradient_scaling_factor
self.expert_gradient_scaling_factor = expert_gradient_scaling_factor
self.device = device
self.only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad = (
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad
)
self.reset_parameters_for_meta_device_init_module = (
reset_parameters_for_meta_device_init_module
)
# User buffer registration related settings
if self.ddp_config.nccl_ub:
# Since user buffer registration requires persistent (non-dynamically allocated)
# memory, it always uses the FSDP double buffer.
self.ddp_config.fsdp_double_buffer = True
# Initialize the NCCL memory pool.
global NCCL_MEMORY_POOL
NCCL_MEMORY_POOL = nccl_allocator.create_nccl_mem_pool()
if torch.distributed.get_rank() == 0:
logging.info(
f"[Rank {torch.distributed.get_rank()}] Created NCCL memory pool for \
UserBuffer Registration"
)
logging.info(
f"[Rank {torch.distributed.get_rank()}] FSDP double buffer is enabled."
)
# If using nccl_ub, it returns a function that registers buffers to the NCCL memory pool
# Buffer is registered to data_parallel_group and expert_data_parallel_group if it exists
# In the case of not using nccl_ub, it returns a nullcontext
self.mem_alloc_context = self.get_mem_alloc_context(
group=self.data_parallel_group, additional_group=self.expert_data_parallel_group
)
# Mark fp8 param.
meta_device_init_fp8_params = {}
if reset_parameters_for_meta_device_init_module:
for m in module.modules():
if not isinstance(m, TransformerEngineBaseModule):
continue
for name, param in m.named_parameters(recurse=False):
# A param initialized from the meta device may NOT end up as an fp8
# tensor; TE's internal logic determines whether this parameter is
# fp8 or not.
fp8_meta_index = m.param_init_meta[name].fp8_meta_index
if m.primary_weights_in_fp8 and fp8_meta_index is not None:
meta_device_init_fp8_params[self.param_to_name[param]] = True
# Get the parameter groups.
(self.parameter_groups, self.param_to_param_group, self.bucket_to_bucket_group) = (
_get_parameter_groups(module, bucketing_policy, meta_device_init_fp8_params)
)
self._init_each_parameter_group_buffers(meta_device_init_fp8_params)
# Initialize the optimizer named parameters.
self.optimizer_named_parameters = self._init_optimizer_named_parameters()
self._log_parameter_groups()
def get_mem_alloc_context(self, group=None, additional_group=None):
"""
Get the memory allocation context for the parameter and gradient buffers.
"""
if self.ddp_config.nccl_ub:
assert nccl_allocator is not None, "NCCL allocator is not available."
global NCCL_MEMORY_POOL
if group is None:
# data parallel group is a default group for user buffer registration
group = self.data_parallel_group
if additional_group is None:
# register buffers to the default group directly using apex memory allocator
mem_alloc_context = functools.partial(
nccl_allocator.nccl_mem, NCCL_MEMORY_POOL, group=group
)
else:
# In case of MoE, we need to register buffer to both DP and EP communicator groups.
# Custom DualUBRAllocator class is used to register buffers to both groups.
# Register buffers to the data_parallel_group using apex memory allocator
# and register buffers to the expert_data_parallel_group.
assert group != additional_group, "Group and additional group must be different."
mem_alloc_context = functools.partial(
DualUBRAllocator,
NCCL_MEMORY_POOL,
group=group,
additional_group=additional_group,
)
return mem_alloc_context
else:
return nullcontext
def _log_parameter_groups(self):
"""
Log the parameter groups for all pipeline stages.
"""
# Log buckets for all PP stages.
if (
parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0
and parallel_state.get_tensor_model_parallel_rank() == 0
):
bucket_groups = self.parameter_groups
param_to_name = self.param_to_name
log_strs = []
log_strs.append(f'Number of parameter groups for FSDP: {len(bucket_groups)}')
for index, group in enumerate(bucket_groups):
numel = 0
for param in group.params:
numel += param.numel()
log_strs.append(
f"Params for group {index+1} ({numel} elements, dtype: {group.dtype}, "
f"fsdp_unit_id: {group.fsdp_unit_id}, "
f"has_weight_buffer: {group.model_weight_buffer is not None}, "
f"has_grad_buffer: {group.main_grad_buffer is not None}, "
f"has_main_weight_buffer: {group.main_weight_buffer is not None}):"
)
for param in group.params:
log_strs.append(f'\t{param_to_name[param]}')
log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))
def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
"""
Initialize the buffers for each parameter group.
"""
data_parallel_sharding_strategy = self.ddp_config.data_parallel_sharding_strategy
if data_parallel_sharding_strategy == 'no_shard':
is_model_weight_buffer_distributed = False
is_main_weight_buffer_distributed = False
is_grad_buffer_distributed = False
elif data_parallel_sharding_strategy == 'optim':
is_model_weight_buffer_distributed = False
is_main_weight_buffer_distributed = True
is_grad_buffer_distributed = False
elif data_parallel_sharding_strategy == 'optim_grads':
is_model_weight_buffer_distributed = False
is_main_weight_buffer_distributed = True
is_grad_buffer_distributed = True
elif data_parallel_sharding_strategy == 'optim_grads_params':
is_model_weight_buffer_distributed = True
is_main_weight_buffer_distributed = True
is_grad_buffer_distributed = True
else:
raise ValueError(
f'Invalid data_parallel_sharding_strategy: {data_parallel_sharding_strategy}'
)
if self.ddp_config.nccl_ub:
assert self.ddp_config.fsdp_double_buffer, (
"NCCL UB is only supported with FSDP double buffer. "
"Please set fsdp_double_buffer=True in the ddp config."
)
if self.ddp_config.fsdp_double_buffer:
UB_BUFFER_NUM = 2
self.weight_alloc = FixedPoolAllocator(
name="fsdp_params", fsdp_param_groups=self.parameter_groups, size=UB_BUFFER_NUM
)
self.main_grad_alloc = FixedPoolAllocator(
name="fsdp_grads", fsdp_param_groups=self.parameter_groups, size=UB_BUFFER_NUM
)
self.double_buf_units = self.weight_alloc.fsdp_double_buffer_units
else:
self.weight_alloc = StorageResizeBasedBucketAllocator()
self.main_grad_alloc = None
self.buffer_all_in_one = True
preserve_fp32_weights = self.preserve_fp32_weights
grad_reduce_in_fp32 = self.grad_reduce_in_fp32
buffer_size = {torch.float32: 0, torch.float16: 0, torch.bfloat16: 0, "float8": 0}
for group_id, group in enumerate(self.parameter_groups):
dp_group = (
self.data_parallel_group
if not group.is_expert_param
else self.expert_data_parallel_group
)
group.data_parallel_world_size = dp_group.size()
gradient_scaling_factor = (
self.gradient_scaling_factor
if not group.is_expert_param
else self.expert_gradient_scaling_factor
)
one_param = group.params[0]
is_dtype_float8 = is_float8tensor(one_param) or meta_device_init_fp8_params.get(
self.param_to_name[one_param], False
)
if is_dtype_float8:
param_dtype = torch.uint8
grad_dtype = torch.bfloat16
else:
param_dtype = group.params[0].dtype
grad_dtype = param_dtype
should_create_grad_buffer_or_main_weight_buffer = (
not self.only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad
or group.requires_grad
)
# Initialize the model weight buffer.
if data_parallel_sharding_strategy != 'no_shard':
group.model_weight_buffer = DataParallelBuffer(
self.ddp_config,
group.params,
is_data_distributed=is_model_weight_buffer_distributed
and group.data_parallel_world_size > 1,
dtype=param_dtype,
device=self.device,
data_parallel_group=dp_group,
inter_data_parallel_group=self.inter_data_parallel_group,
init_meta_only=True,
is_dtype_float8=is_dtype_float8,
temporary_bucket_allocator=self.weight_alloc,
bucket_id=group_id,
mem_alloc_context=self.mem_alloc_context,
)
# Initialize the main weight buffer.
if should_create_grad_buffer_or_main_weight_buffer and preserve_fp32_weights:
group.main_weight_buffer = DataParallelBuffer(
self.ddp_config,
group.params,
is_data_distributed=is_main_weight_buffer_distributed
and group.data_parallel_world_size > 1,
dtype=torch.float32,
device=self.device,
data_parallel_group=dp_group,
inter_data_parallel_group=self.inter_data_parallel_group,
init_meta_only=True,
bucket_id=group_id,
mem_alloc_context=self.mem_alloc_context,
)
# Initialize the main grad buffer.
if should_create_grad_buffer_or_main_weight_buffer:
group.main_grad_buffer = DataParallelBuffer(
self.ddp_config,
group.params,
is_data_distributed=is_grad_buffer_distributed
and group.data_parallel_world_size > 1,
dtype=torch.float32 if grad_reduce_in_fp32 else grad_dtype,
device=self.device,
data_parallel_group=dp_group,
inter_data_parallel_group=self.inter_data_parallel_group,
init_meta_only=True,
is_dtype_float8=not grad_reduce_in_fp32 and grad_dtype is torch.uint8,
temporary_bucket_allocator=self.main_grad_alloc,
gradient_scaling_factor=gradient_scaling_factor,
bucket_id=group_id,
mem_alloc_context=self.mem_alloc_context,
)
if grad_reduce_in_fp32:
buffer_size[torch.float32] += group.main_grad_buffer.data_size
elif group.main_grad_buffer.is_dtype_float8:
buffer_size["float8"] += group.main_grad_buffer.data_size
else:
buffer_size[group.main_grad_buffer.dtype] += group.main_grad_buffer.data_size
reset_context_args = {"init_param_with_fp8": self.ddp_config.fp8_param_gather}
module_reset_flag = {}
if self.reset_parameters_for_meta_device_init_module:
self.param_to_direct_module = {}
for name, m in self.module.named_modules():
for p in m.parameters(recurse=False):
self.param_to_direct_module[p] = (name, m)
meta_params_numel = 0
cuda_params_numel = 0
cpu_params_numel = 0
for group in self.parameter_groups:
for p in group.params:
if p.is_meta:
meta_params_numel += p.numel()
elif p.device.type == 'cuda':
cuda_params_numel += p.numel()
else:
cpu_params_numel += p.numel()
log_str = (
f"Meta params numel: {meta_params_numel / 1_000_000:.2f} M, "
f"CUDA params numel: {cuda_params_numel / 1_000_000:.2f} M, "
f"CPU params numel: {cpu_params_numel / 1_000_000:.2f} M"
)
log_on_each_pipeline_stage(logger, logging.INFO, log_str)
# Initialize the model weight buffer data of each parameter group.
for group in self.parameter_groups:
wbuf = group.model_weight_buffer
if wbuf:
with self.mem_alloc_context():
wbuf.data = torch.empty(wbuf.data_size, dtype=wbuf.dtype, device=self.device)
bucket = wbuf.fetch_bucket()
mbuf = group.main_weight_buffer
if mbuf:
mbuf.data = torch.empty(mbuf.data_size, dtype=mbuf.dtype, device=self.device)
for item_id, p in enumerate(group.params):
if wbuf:
if self.reset_parameters_for_meta_device_init_module and p.is_meta:
m_name, m = self.param_to_direct_module[p]
if not module_reset_flag.get(m_name, False) and hasattr(
m, "reset_parameters"
):
old_params = list(m.parameters(recurse=False))
# If GPU memory usage is over the threshold, empty the cache to
# leave some room for initializing the model on the CUDA device.
if check_gpu_memory(threshold=0.5):
gc.collect()
torch.cuda.empty_cache()
m.to_empty(device=self.device, recurse=False)
if is_te_min_version("0.9.0") and not isinstance(
m, TransformerEngineBaseModule
):
reset_context_args["with_cuda_rng_tracker"] = True
with ResetParametersContext(**reset_context_args):
m.reset_parameters()
module_reset_flag[m_name] = True
new_params = list(m.parameters(recurse=False))
self._reset_parameters(old_params, new_params)
p = group.params[item_id]
# After resetting parameters, delete the fp8 transpose cache
# if we do not need to keep it.
if not self.ddp_config.keep_fp8_transpose_cache_when_using_custom_fsdp:
for _param in m.parameters(recurse=False):
if is_float8tensor(_param):
_param._transpose_invalid = True
_param._transpose = None
assert not p.is_meta, (self.param_to_name[p], module_reset_flag)
wbuf.set_item(item_id, p)
# reset the parameter data to the buffer
new_param_data = wbuf.get_item_from_bucket(bucket, item_id).view(p.shape)
if is_float8tensor(p):
modify_underlying_storage(p, new_param_data)
else:
old_param_data = p.data
p.data = new_param_data
assert old_param_data._base is None
p.data.detach().copy_(old_param_data)
del old_param_data
if mbuf:
if hasattr(p, 'get_high_precision_init_val'):
mbuf.set_item(item_id, p.get_high_precision_init_val())
p.clear_high_precision_init_val()
else:
mbuf.set_item(item_id, p)
if wbuf and wbuf.is_data_distributed:
"""
When MCore Custom FSDP `optim_grads_params` is enabled,
it is necessary to save the tensor local shard. This local shard is
accessible through the `fully_shard_param_local_shard`
attribute of the tensor.
This attribute contains the local shard of the fully
sharded parameter, which is essential for correctly
saving and loading the model state when using
`optim_grads_params` with FSDP.
Example:
>>> # Assuming `tensor` is a fully sharded parameter
>>> local_shard = tensor.fully_shard_param_local_shard
>>> # Save the local shard as needed
"""
local_shard = wbuf.get_item(item_id, only_shard=True)
local_shard.fsdp_shard_orig_param = p
p.fully_shard_param_local_shard = local_shard
p.fully_shard_param_local_index = wbuf.locate_item_in_global_item(item_id)
if self.ddp_config.num_distributed_optimizer_instances > 1:
p.fsdp_instance_id = torch.distributed.get_rank(
self.inter_data_parallel_group
)
else:
p.fsdp_instance_id = 0
if wbuf and wbuf.is_data_distributed:
wbuf.free_bucket_storage()
# Allocate the main_weight buffer and main_grad buffer data in one buffer.
if self.buffer_all_in_one:
with self.mem_alloc_context():
self.buffer = {
torch.float32: torch.empty(
buffer_size[torch.float32], dtype=torch.float32, device=self.device
),
torch.float16: torch.empty(
buffer_size[torch.float16], dtype=torch.float16, device=self.device
),
torch.bfloat16: torch.empty(
buffer_size[torch.bfloat16], dtype=torch.bfloat16, device=self.device
),
"float8": torch.empty(
buffer_size["float8"], dtype=torch.uint8, device=self.device
),
}
offset = {torch.float32: 0, torch.float16: 0, torch.bfloat16: 0, "float8": 0}
def _alloc(dtype, size):
if self.buffer_all_in_one:
if dtype == torch.uint8:
dtype = "float8"
data = self.buffer[dtype][offset[dtype] : offset[dtype] + size]
offset[dtype] += size
return data
return torch.empty(size, dtype=dtype, device=self.device)
# Initialize the main grad buffer data of each parameter group.
for group in self.parameter_groups:
gbuf = group.main_grad_buffer
if not gbuf:
continue
with self.mem_alloc_context():
gbuf.data = _alloc(gbuf.dtype, gbuf.data_size)
gbuf.data.zero_()
for item_id, p in enumerate(group.params):
p.fsdp_managed_main_grad = gbuf.get_item(item_id)
p._gbuf = gbuf
p._item_id = item_id
def main_grad_getter(p):
# Make sure main_grad memory storage ready.
bucket = p._gbuf.fetch_bucket()
gbuf = p._gbuf
item_id = p._item_id
return gbuf.get_item_from_bucket(bucket, item_id).view(p.shape)
setattr(p.__class__, 'main_grad', property(main_grad_getter))
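# Note (illustrative, not original code): with the property installed above,
# accessing `param.main_grad` lazily materializes the grad bucket, e.g. in a
# gradient-accumulation hook one might write:
#   param.main_grad.add_(param.grad)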
if gbuf.is_data_distributed:
gbuf.free_bucket_storage()
gc.collect()
torch.cuda.empty_cache()
def _reset_parameters(self, old_params, new_params):
assert len(old_params) == len(new_params)
param_map = {}
for old_param, new_param in zip(old_params, new_params):
param_map[old_param] = new_param
self.param_to_name[new_param] = self.param_to_name[old_param]
del self.param_to_name[old_param]
self.param_to_param_group[new_param] = self.param_to_param_group[old_param]
del self.param_to_param_group[old_param]
self.param_to_direct_module[new_param] = self.param_to_direct_module[old_param]
del self.param_to_direct_module[old_param]
for item_id, p in enumerate(self.params):
if p in param_map:
new_p = param_map[p]
self.params[item_id] = new_p
for group in self.parameter_groups:
for item_id, p in enumerate(group.params):
if p not in param_map:
continue
new_p = param_map[p]
group.params[item_id] = new_p
for buf in [
group.model_weight_buffer,
group.main_weight_buffer,
group.main_grad_buffer,
]:
if buf is None:
continue
buf.param_idx[new_p] = buf.param_idx[p]
del buf.param_idx[p]
def scale_gradients(self, scaling_factor: float) -> None:
"""Scale the gradient data by `scaling_factor`."""
for group in self.parameter_groups:
if group.main_grad_buffer is None:
continue
group.main_grad_buffer.data *= scaling_factor
self.update_main_grads()
def zero_grad(self):
"""
Zero out the underlying grad_buffer and reset all buckets in preparation
for the next iteration of training.
"""
for _, param in self.optimizer_named_parameters:
if param.grad is not None and param.grad._base is None:
# For grad tensors that are not referenced elsewhere, resize the
# storage to zero so the memory is freed immediately.
_free_storage(param.grad)
param.grad = None
for group in self.parameter_groups:
if group.main_grad_buffer is None:
continue
group.main_grad_buffer.data.zero_()
def _init_optimizer_named_parameters(self) -> List[Tuple[str, torch.nn.Parameter]]:
named_parameters = []
for pg in self.parameter_groups:
if pg.main_grad_buffer is None:
continue
optimizer_state_is_shard = pg.main_grad_buffer.is_data_distributed or (
pg.main_weight_buffer and pg.main_weight_buffer.is_data_distributed
)
for item_id, orig_param in enumerate(pg.params):
if pg.main_weight_buffer:
param = pg.main_weight_buffer.get_item(
item_id, only_shard=optimizer_state_is_shard
)
elif pg.model_weight_buffer:
param = pg.model_weight_buffer.get_item(
item_id, only_shard=optimizer_state_is_shard
)
else:
param = orig_param
def set_param_attribute_closure(param, orig_param):
def set_param_attribute():
for attr_name in [
'requires_grad',
'sequence_parallel',
'shared',
'tensor_model_parallel',
'partition_dim',
'partition_stride',
'is_embedding_or_output_parameter',
]:
if hasattr(orig_param, attr_name):
setattr(param, attr_name, getattr(orig_param, attr_name))
return set_param_attribute
setattr(param, 'reset_attribute', set_param_attribute_closure(param, orig_param))
setattr(param, 'orig_param', orig_param)
param.reset_attribute()
named_parameters.append((self.param_to_name[orig_param], param))
return named_parameters
def update_main_grads(self):
"""Update the main gradients for preparing the optimizer step."""
update_shard_main_grad = self.ddp_config.data_parallel_sharding_strategy in [
'optim',
'optim_grads',
'optim_grads_params',
]
for _, param in self.optimizer_named_parameters:
param.reset_attribute()
orig_param = param.orig_param
group = self.parameter_groups[self.param_to_param_group[orig_param]]
item_id = group.main_grad_buffer.param_idx[orig_param]
optimizer_grad = group.main_grad_buffer.get_item(
item_id, only_shard=update_shard_main_grad
)
# The presence of main_grad_buffer but no main_weight_buffer means
# that a precision-aware optimizer is used.
if group.main_weight_buffer is None:
setattr(
param, 'decoupled_grad', optimizer_grad if optimizer_grad.numel() > 0 else None
)
else:
setattr(
param,
'grad',
optimizer_grad.to(param.dtype) if optimizer_grad.numel() > 0 else None,
)
@property
def num_buckets(self):
"""Return the number of buckets."""
return len(self.parameter_groups)
@torch.no_grad()
def copy_main_weights_to_model_weights(self):
"""Update the model weights from the main weights."""
for pg in self.parameter_groups:
mbuf = pg.main_weight_buffer
wbuf = pg.model_weight_buffer
if mbuf is None:
continue
fp8_params = []
shard_fp32_from_fp8 = []
shard_offsets_in_fp8 = []
shard_model_params = []
for param in pg.params:
item_id = mbuf.param_idx[param]
if wbuf:
if wbuf.is_data_distributed or mbuf.is_data_distributed:
model_param = wbuf.get_item(item_id, only_shard=True)
main_weight = mbuf.get_item(item_id, only_shard=True)
else:
model_param = wbuf.get_item(item_id)
main_weight = mbuf.get_item(item_id)
else:
assert not mbuf.is_data_distributed
model_param = param
main_weight = pg.main_weight_buffer.get_item(item_id)
if is_float8tensor(param):
fp8_params.append(param)
if model_param.numel() == 0:
shard_fp32_from_fp8.append(None)
shard_offsets_in_fp8.append(None)
shard_model_params.append(None)
else:
shard_fp32_from_fp8.append(main_weight)
shard_offsets_in_fp8.append(wbuf.locate_item_in_global_item(item_id)[0])
shard_model_params.append(model_param)
continue
if model_param.numel() > 0:
model_param.data.copy_(main_weight.view(model_param.shape))
quantize_param_shard(
fp8_params,
shard_fp32_from_fp8,
shard_offsets_in_fp8,
wbuf.data_parallel_group,
shard_model_params,
)
@torch.no_grad()
def copy_model_weights_to_main_weights(self):
"""Copy the model weights to the main weights."""
for group in self.parameter_groups:
mbuf = group.main_weight_buffer
if mbuf is None:
continue
wbuf = group.model_weight_buffer
if mbuf.is_data_distributed:
copyin_data = wbuf.get_shard_from_local_buffer()
else:
copyin_data = wbuf.data
assert mbuf.data.numel() == copyin_data.numel(), (
f"Master weight buffer size {mbuf.data.numel()} does not match "
f"model weight buffer size {copyin_data.numel()}"
)
mbuf.data.copy_(copyin_data.data)
def all_gather_parameters(self, async_op: bool = True):
"""All gather the parameters.
Args:
async_op (bool, optional): Whether to issue the all-gather
asynchronously. Defaults to True.
"""
assert all(
[not g.model_weight_buffer.is_data_distributed for g in self.parameter_groups]
), 'all_gather_parameters() should only be called when parameters are not sharded.'
all_gather_ops = []
for g in self.parameter_groups:
shard = g.model_weight_buffer.get_shard_from_local_buffer()
all_gather_handler = torch.distributed.all_gather_into_tensor(
output_tensor=g.model_weight_buffer.data,
input_tensor=shard,
group=g.model_weight_buffer.data_parallel_group,
async_op=async_op,
)
if async_op:
all_gather_ops.append(all_gather_handler)
for op in all_gather_ops:
op.wait()
def reduce_scatter_gradients(self, async_op: bool = True):
"""Reduce scatter the gradients.
Args:
async_op (bool, optional): Whether to issue the reduce-scatter
asynchronously. Defaults to True.
"""
assert all(
[not g.main_grad_buffer.is_data_distributed for g in self.parameter_groups]
), 'reduce_scatter_gradients() should only be called when gradients are not sharded.'
reduce_scatter_ops = []
for g in self.parameter_groups:
gbuf = g.main_grad_buffer
if gbuf is None:
continue
scaling_factor = gbuf.gradient_scaling_factor
reduce_op = gradient_reduce_preprocessing(gbuf.data, scaling_factor, self.ddp_config)
reduce_scatter_handler = torch.distributed.reduce_scatter_tensor(
output=gbuf.get_shard_from_local_buffer(),
input=gbuf.data,
op=reduce_op,
group=g.main_grad_buffer.data_parallel_group,
async_op=async_op,
)
if async_op:
reduce_scatter_ops.append(reduce_scatter_handler)
for op in reduce_scatter_ops:
op.wait()
def all_reduce_gradients(self, async_op: bool = False):
"""All reduce the gradients.
Args:
async_op (bool, optional): Whether to do the all-reduce
asynchronously. Defaults to False.
"""
assert all(
[
not g.main_grad_buffer.is_data_distributed
for g in self.parameter_groups
if g.main_grad_buffer
]
), 'all_reduce_gradients() should only be called when gradients are not sharded.'
all_reduce_ops = []
for g in self.parameter_groups:
gbuf = g.main_grad_buffer
if gbuf is None:
continue
scaling_factor = gbuf.gradient_scaling_factor
reduce_op = gradient_reduce_preprocessing(gbuf.data, scaling_factor, self.ddp_config)
all_reduce_handler = torch.distributed.all_reduce(
gbuf.data, op=reduce_op, group=gbuf.data_parallel_group, async_op=async_op
)
if async_op:
all_reduce_ops.append(all_reduce_handler)
for op in all_reduce_ops:
op.wait()
class BucketStatus(Enum):
"""
An enumeration of possible statuses for a data-parallel communication bucket.
Attributes:
EMPTY (int): The bucket is empty and not in use.
COMMUNICATING (int): The bucket is currently being used for communication.
READY_TO_USE (int): The bucket is filled with data and ready for use.
"""
EMPTY = 1
COMMUNICATING = 2
READY_TO_USE = 3
class GradReducePipeline:
"""
Pipeline for reducing gradients.
"""
def __init__(
self,
param_and_grad_buffer: ParamAndGradBuffer,
rs_stream: Optional[torch.cuda.Stream] = None,
check_nans: bool = False,
inter_fsdp_group_grad_reduce: bool = False,
) -> None:
self.buffer = param_and_grad_buffer
self.grad_reduce_queue = []
self.bucket_status = {
i: BucketStatus.EMPTY
for i in range(self.buffer.num_buckets)
if self.buffer.parameter_groups[i].main_grad_buffer
}
self.bucket_grad_ready_params = [set() for _ in range(self.buffer.num_buckets)]
self.rs_stream = rs_stream
self.check_nans = check_nans
self.inter_fsdp_group_grad_reduce = inter_fsdp_group_grad_reduce
if inter_fsdp_group_grad_reduce:
self.hsdp_all_reduce_stream = torch.cuda.Stream()
@property
def num_buckets(self):
"""Return the number of buckets."""
return self.buffer.num_buckets
def reset(self):
"""Handle the processing tasks and reset the pipeline."""
self.wait_for_previous_grad_reduce(0)
for bucket_id, grad_ready_params in enumerate(self.bucket_grad_ready_params):
param_list = self.buffer.parameter_groups[bucket_id].params
n_params = len(param_list)
param_to_name = self.buffer.param_to_name
assert len(grad_ready_params) == 0, (
f"Found {len(grad_ready_params)} out of {n_params} parameters that are ready for "
f"reduce-scatter/all-reduce, but the pipeline is being reset. "
f"grad_ready_params: {[param_to_name[p] for p in grad_ready_params]} "
f"param_list: {[param_to_name[p] for p in param_list]}"
)
for bucket_id, _ in self.bucket_status.items():
gbuf = self.buffer.parameter_groups[bucket_id].main_grad_buffer
gbuf.free_bucket_storage()
self.bucket_status[bucket_id] = BucketStatus.EMPTY
def reduce_gradients(
self,
params: List[torch.Tensor],
suggested_queue_capacity: Optional[int] = None,
inter_fsdp_group_grad_reduce: bool = False,
async_grad_reduce: bool = True,
):
"""Reduce the gradients for the given parameters.
Args:
params (List[torch.Tensor]): The parameters.
suggested_queue_capacity (int, optional): The suggested queue capacity.
Defaults to None.
inter_fsdp_group_grad_reduce (bool, optional): Whether to use inter-group
gradient reduction. Defaults to False.
async_grad_reduce (bool, optional): Whether to do the gradient-reduce
asynchronously. Defaults to True.
"""
for param in params:
bucket_id = self.buffer.param_to_param_group[param]
param_group = self.buffer.parameter_groups[bucket_id]
if not param.requires_grad:
assert param_group.requires_grad is False, (
f"Param {self.buffer.param_to_name[param]} has requires_grad=False, "
f"but it is in a parameter group with requires_grad=True."
)
continue
assert param_group.requires_grad, (
f"Param {self.buffer.param_to_name[param]} has requires_grad=True, "
f"but it is in a parameter group with requires_grad=False."
)
# Mark grad as ready for reduce-scatter/all-reduce.
self.bucket_grad_ready_params[bucket_id].add(param)
if len(self.bucket_grad_ready_params[bucket_id]) == len(param_group.params):
self.wait_for_previous_grad_reduce(
suggested_queue_capacity=suggested_queue_capacity
)
self.mark_bucket_ready(
bucket_id, inter_fsdp_group_grad_reduce, async_op=async_grad_reduce
)
def wait_for_previous_grad_reduce(
self, suggested_queue_size: int = 1, suggested_queue_capacity: Optional[int] = None
):
"""
Wait for the previous reduce-scatter/all-reduce to finish.
Args:
suggested_queue_size (int, optional): The recommended queue size. Defaults to 1.
suggested_queue_capacity (Optional[int], optional): The recommended queue capacity.
Defaults to None.
"""
if suggested_queue_capacity is not None:
queue_space = sum(
[
self.buffer.parameter_groups[bucket_id].main_grad_buffer.bucket_index.size
for _, _, bucket_id in self.grad_reduce_queue
]
)
while queue_space > suggested_queue_capacity:
grad_reduce_event, free_up_grad_bucket, bucket_id = self.grad_reduce_queue.pop(0)
grad_reduce_event.wait()
free_up_grad_bucket()
queue_space -= self.buffer.parameter_groups[
bucket_id
].main_grad_buffer.bucket_index.size
else:
suggested_queue_size = max(0, min(suggested_queue_size, self.buffer.num_buckets - 1))
while len(self.grad_reduce_queue) > suggested_queue_size:
grad_reduce_event, free_up_grad_bucket, _ = self.grad_reduce_queue.pop(0)
grad_reduce_event.wait()
free_up_grad_bucket()
if suggested_queue_size == 0 and self.inter_fsdp_group_grad_reduce:
torch.cuda.current_stream().wait_stream(self.hsdp_all_reduce_stream)
def _enforce_double_buffer_limit(self, add_buckets):
if not self.buffer.ddp_config.fsdp_double_buffer:
return
param_groups = self.buffer.parameter_groups
double_buf_units = set()
for bucket_id in add_buckets:
fsdp_unit_id = param_groups[bucket_id].fsdp_unit_id
if fsdp_unit_id in self.buffer.double_buf_units:
double_buf_units.add(fsdp_unit_id)
assert len(double_buf_units) <= 2, (
f"Double buffer limit exceeded. " f"Current double_buf_units: {double_buf_units}."
)
keep_n = len(self.grad_reduce_queue)
for _, _, bucket_id in reversed(self.grad_reduce_queue):
fsdp_unit_id = param_groups[bucket_id].fsdp_unit_id
double_buf_units.add(fsdp_unit_id)
if len(double_buf_units) > 2:
keep_n -= 1
self.wait_for_previous_grad_reduce(keep_n)
def _bucket_group_gradient_reduce(
self,
bucket_group: List[int],
async_op: bool = False,
inter_fsdp_group_grad_reduce: bool = False,
):
"""Mark the bucket ready for reduce-scatter/all-reduce, if all bucket in
the bucket group are ready, then do the reduce-scatter/all-reduce.
Args:
bucket_id (int): The bucket to be marked.
async_rs (bool, optional): Whether to do the reduce-scatter/all-reduce
asynchronously. Defaults to False.
Returns:
bool: True if the bucket is go for reduce-scatter/all-reduce.
"""
# When using the FSDP double buffer, waiting for the necessary buckets to be
# released ensures that the double buffer does not run out of space due to
# too many outstanding bucket requests.
if self.buffer.ddp_config.fsdp_double_buffer:
self._enforce_double_buffer_limit(bucket_group)
current_stream = torch.cuda.current_stream()
reduce_scatter_stream = (
self.rs_stream if self.rs_stream is not None else torch.cuda.current_stream()
)
reduce_scatter_stream.wait_stream(current_stream)
dp_group = self.buffer.parameter_groups[
bucket_group[0]
].main_grad_buffer.data_parallel_group
with torch.cuda.stream(reduce_scatter_stream):
with _coalescing_manager(dp_group, async_ops=async_op) as coalescing_event:
grad_shards = {}
for bucket_id in bucket_group:
gbuf = self.buffer.parameter_groups[bucket_id].main_grad_buffer
bucket = gbuf.fetch_bucket()
scaling_factor = gbuf.gradient_scaling_factor
reduce_op = gradient_reduce_preprocessing(
gbuf.data, scaling_factor, gbuf.ddp_config
)
if gbuf.ddp_config.data_parallel_sharding_strategy == 'no_shard':
torch.distributed.all_reduce(
bucket.data, op=reduce_op, group=gbuf.data_parallel_group
)
else:
grad_shard = gbuf.get_shard_from_bucket(bucket)
# pylint: disable=C0301
# `grad_shard` is a view into `bucket.data`; allocating a fresh tensor below
# (torch.empty_like) is important for memory safety when
# TORCH_NCCL_AVOID_RECORD_STREAMS=1 is set.
# For reference: https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486
if not self.buffer.ddp_config.fsdp_double_buffer:
grad_shard = torch.empty_like(grad_shard)
torch.distributed.reduce_scatter_tensor(
output=grad_shard,
input=bucket.data,
op=reduce_op,
group=gbuf.data_parallel_group,
)
grad_shards[bucket_id] = grad_shard
self.bucket_status[bucket_id] = BucketStatus.COMMUNICATING
coalescing_event.wait()
for bucket_id in bucket_group:
# Local gradient accumulate
gbuf = self.buffer.parameter_groups[bucket_id].main_grad_buffer
if gbuf.ddp_config.data_parallel_sharding_strategy != 'no_shard':
# Gradient accumulate on local buffer
local_buffer = gbuf.get_shard_from_local_buffer()
local_buffer += grad_shards[bucket_id]
reduce_scatter_view_out_event = reduce_scatter_stream.record_event()
# Gradient reduction within the model replication domain
if inter_fsdp_group_grad_reduce:
ddp_config = self.buffer.ddp_config
assert ddp_config.data_parallel_sharding_strategy != 'no_shard'
self.hsdp_all_reduce_stream.wait_stream(reduce_scatter_stream)
inter_data_parallel_group = self.buffer.parameter_groups[
bucket_group[0]
].main_grad_buffer.inter_data_parallel_group
with torch.cuda.stream(self.hsdp_all_reduce_stream):
with _coalescing_manager(inter_data_parallel_group):
for bucket_id in bucket_group:
gbuf = self.buffer.parameter_groups[bucket_id].main_grad_buffer
grad_local_buffer = gbuf.get_shard_from_local_buffer()
if ddp_config.average_in_collective:
reduce_op = torch.distributed.ReduceOp.AVG
else:
reduce_op = torch.distributed.ReduceOp.SUM
torch.distributed.all_reduce(
grad_local_buffer, group=gbuf.inter_data_parallel_group, op=reduce_op
)
free_up_grad_bucket_func = {}
for bucket_id in bucket_group:
def get_closure(bucket_id):
def free_up_grad_bucket():
self.bucket_grad_ready_params[bucket_id] = set()
gbuf = self.buffer.parameter_groups[bucket_id].main_grad_buffer
if gbuf.is_data_distributed:
gbuf.free_bucket_storage()
self.bucket_status[bucket_id] = BucketStatus.EMPTY
return free_up_grad_bucket
free_up_grad_bucket_func[bucket_id] = get_closure(bucket_id)
if async_op:
for bucket_id, free_up_grad_bucket in free_up_grad_bucket_func.items():
self.grad_reduce_queue.append(
(reduce_scatter_view_out_event, free_up_grad_bucket, bucket_id)
)
return
reduce_scatter_view_out_event.wait()
for free_up_grad_bucket in free_up_grad_bucket_func.values():
free_up_grad_bucket()
def mark_bucket_ready(
self, bucket_id: int, inter_fsdp_group_grad_reduce: bool = False, async_op: bool = True
) -> bool:
"""Mark the bucket ready for gradient reduce, if all bucket in the bucket group
are ready, reduce-scatter or all-reduce gradient bucket, in the case of HSDP,
there is an additional all-reduce in the model replication domain.
Args:
bucket_id (int): The bucket to be marked ready to reduce-scatter or
all-reduce.
inter_fsdp_group_grad_reduce (bool, optional): Whether to use inter-group
gradient reduction. Defaults to False.
async_op (bool, optional): Whether to do the gradient-reduce
asynchronously. Defaults to True.
Returns:
bool: True if the bucket group was dispatched for reduce-scatter/all-reduce.
"""
# Prepare the bucket group for gradient reduction. Note that some buckets
# contain only parameters that do not require grad, so remove those buckets
# from the bucket group.
bucket_group = self.buffer.bucket_to_bucket_group[bucket_id]
bucket_group = [i for i in bucket_group if self.buffer.parameter_groups[i].main_grad_buffer]
# If any bucket in the bucket group is not ready, skip the gradient reduce
# waiting for the bucket group to be all ready before executing.
for bucket_id in bucket_group:
param_group = self.buffer.parameter_groups[bucket_id]
if len(self.bucket_grad_ready_params[bucket_id]) != len(param_group.params):
return False
self._bucket_group_gradient_reduce(
bucket_group,
async_op=async_op,
inter_fsdp_group_grad_reduce=inter_fsdp_group_grad_reduce,
)
return True
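# Illustrative usage sketch for GradReducePipeline (hypothetical driver code):
#   pipeline = GradReducePipeline(param_and_grad_buffer)
#   pipeline.reduce_gradients([param], async_grad_reduce=True)  # e.g. from a backward hook
#   pipeline.wait_for_previous_grad_reduce(0)  # drain the queue before the optimizer step
#   pipeline.reset()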
class PrefetchOrder(Enum):
"""
An enumeration of possible prefetch orders for data-parallel operations.
Attributes:
FORWARD_PASS_ORDER (int): Prefetch in the order of forward pass computation.
BACKWARD_PASS_ORDER (int): Prefetch in the order of backward pass computation.
"""
FORWARD_PASS_ORDER = 0
BACKWARD_PASS_ORDER = 1
class AllGatherPipeline:
"""
Pipeline for all-gathering parameters.
"""
def __init__(self, param_and_grad_buffer: ParamAndGradBuffer) -> None:
self.buffer = param_and_grad_buffer
self.param_gather_event_map = {}
self.bucket_status = {i: BucketStatus.EMPTY for i in range(self.buffer.num_buckets)}
self.bucket_can_be_released = {i: False for i in range(self.buffer.num_buckets)}
self.bucket_to_bucket_group = {}
group_id = 0
for bucket_group in self.buffer.bucket_to_bucket_group.values():
new_group = False
for bucket_id in bucket_group:
if bucket_id not in self.bucket_to_bucket_group:
new_group = True
break
if new_group:
group_id += 1
for bucket_id in bucket_group:
self.bucket_to_bucket_group[bucket_id] = group_id
@property
def num_buckets(self):
"""Return the number of buckets."""
return self.buffer.num_buckets
def reset(self):
"""Reset the pipeline state."""
if len(self.param_gather_event_map) > 0:
warnings.warn(
"There are still pending all-gather tasks, process them. "
f"Bucket status: {self.bucket_status}.",
UserWarning,
)
while len(self.param_gather_event_map) > 0:
bucket_id = next(iter(self.param_gather_event_map))
self.wait_bucket_ready(bucket_id)
for bucket_id in self.bucket_can_be_released:
self.bucket_can_be_released[bucket_id] = True
self.recycle_unused_buckets()
assert all([status is BucketStatus.EMPTY for status in self.bucket_status.values()]), (
f"There are still working buckets, it is not safe to reset. "
f"bucket_status: {self.bucket_status}."
)
assert all(
[not can_be_released for can_be_released in self.bucket_can_be_released.values()]
), (
f"The bucket can be released table is in an abnormal state, not safe to reset. "
f"bucket_can_be_released: {self.bucket_can_be_released}."
)
def all_gather_params(
self,
params: List[torch.Tensor],
prefetch: bool = False,
prefetch_order: PrefetchOrder = PrefetchOrder.FORWARD_PASS_ORDER,
suggested_AG_prefetch_size: Optional[int] = None,
async_param_gather: bool = True,
):
"""All-gather the params. If prefetch is enabled, prefetch next buckets
in the order of `prefetch_order`.
Args:
params (List[torch.Tensor]): The list of params to be all-gathered.
prefetch (bool, optional): Whether to prefetch the next bucket. Defaults to False.
prefetch_order (PrefetchOrder, optional): The order of prefetching.
Defaults to PrefetchOrder.FORWARD_PASS_ORDER.
suggested_AG_prefetch_size (Optional[int], optional):
The suggested prefetch size for all-gathering. Defaults to None.
async_param_gather (bool, optional): Whether to issue the all-gather
asynchronously. Defaults to True.
"""
if len(params) == 0:
return
ag_buckets = [self.buffer.param_to_param_group[item] for item in params]
ag_buckets = list(sorted(set(ag_buckets)))
parameter_groups = self.buffer.parameter_groups
if self.buffer.ddp_config.fsdp_double_buffer:
double_buf_units = set()
for bucket_id in ag_buckets:
fsdp_unit_id = parameter_groups[bucket_id].fsdp_unit_id
if fsdp_unit_id in self.buffer.double_buf_units:
double_buf_units.add(fsdp_unit_id)
if len(double_buf_units) > 2:
raise ValueError(
f"{double_buf_units} FSDP units were requested, "
"but double buffers can support no more than 2 FSDP units."
)
# If prefetch is enabled, we will add prefetch buckets to ag_buckets.
if prefetch:
def next_bucket_id(ag_buckets):
if prefetch_order == PrefetchOrder.FORWARD_PASS_ORDER:
bucket_id = ag_buckets[0] + 1
for i in ag_buckets[1:]:
if i != bucket_id:
break
bucket_id += 1
else:
bucket_id = ag_buckets[-1] - 1
for i in reversed(ag_buckets[:-1]):
if i != bucket_id:
break
bucket_id -= 1
if bucket_id < 0 or bucket_id >= self.buffer.num_buckets:
return None
return bucket_id
def need_skip_prefetch(bucket_id):
# If the double buffer is in use, check whether prefetching the next
# bucket would exceed the coverage of the double buffer.
if self.buffer.ddp_config.fsdp_double_buffer:
fsdp_unit_id = parameter_groups[bucket_id].fsdp_unit_id
double_buf_units.add(fsdp_unit_id)
if len(double_buf_units) > 2:
# Prefetching the next bucket will exceed the coverage of
# the double buffer, so we need to stop prefetching.
return True
return False
if suggested_AG_prefetch_size is not None:
bucket_id = next_bucket_id(ag_buckets)
while bucket_id is not None:
all_gather_size = sum(
[
parameter_groups[i].model_weight_buffer.bucket_index.size
for i in ag_buckets
]
)
if all_gather_size >= suggested_AG_prefetch_size:
break
if need_skip_prefetch(bucket_id):
break
ag_buckets.extend(self.buffer.bucket_to_bucket_group[bucket_id])
ag_buckets = list(sorted(set(ag_buckets)))
bucket_id = next_bucket_id(ag_buckets)
else:
bucket_id = next_bucket_id(ag_buckets)
if need_skip_prefetch(bucket_id):
bucket_id = None
if bucket_id is not None:
ag_buckets.extend(self.buffer.bucket_to_bucket_group[bucket_id])
ag_buckets = list(sorted(set(ag_buckets)))
ag_buckets = [i for i in ag_buckets if self.bucket_status[i] == BucketStatus.EMPTY]
if len(ag_buckets) == 0:
return
# Divide buckets into aggregate groups
bucket_group_to_buckets = {}
for bucket_id in ag_buckets:
group_id = self.bucket_to_bucket_group[bucket_id]
if group_id not in bucket_group_to_buckets:
bucket_group_to_buckets[group_id] = []
bucket_group_to_buckets[group_id].append(bucket_id)
# Coalesce all-gather operations for all buckets in the same data-parallel-group
for _, buckets in bucket_group_to_buckets.items():
param_group = parameter_groups[buckets[0]]
dp_group = param_group.model_weight_buffer.data_parallel_group
with _coalescing_manager(dp_group, async_ops=async_param_gather) as coalescing_event:
for bucket_id in buckets:
self.async_bucket_gather(bucket_id)
# reset param gather event with coalescing event
for bucket_id in buckets:
_, mark_bucket_ready_to_use = self.param_gather_event_map[bucket_id]
self.param_gather_event_map[bucket_id] = (
coalescing_event,
mark_bucket_ready_to_use,
)
# Wait for all-gather to finish
if not async_param_gather:
for bucket_id in buckets:
self.wait_bucket_ready(bucket_id)
def wait_bucket_ready(self, bucket_id, empty_ok=False):
"""Wait for the bucket to be ready."""
if self.bucket_status[bucket_id] == BucketStatus.READY_TO_USE:
return
if self.bucket_status[bucket_id] == BucketStatus.EMPTY:
if empty_ok:
return
raise ValueError(f"Bucket {bucket_id} is empty.")
param_gather_event, mark_bucket_ready_to_use = self.param_gather_event_map.pop(bucket_id)
param_gather_event.wait()
mark_bucket_ready_to_use()
@torch.no_grad()
def release_bucket(self, bucket_id: int):
"""Release the bucket."""
if self.bucket_status[bucket_id] == BucketStatus.EMPTY:
return
if self.bucket_status[bucket_id] == BucketStatus.COMMUNICATING:
raise ValueError(f"Bucket {bucket_id} is communicating.")
wbuf = self.buffer.parameter_groups[bucket_id].model_weight_buffer
wbuf.free_bucket_storage()
self.bucket_status[bucket_id] = BucketStatus.EMPTY
def recycle_unused_buckets(self):
"""Recycle the unused buckets."""
for bucket_id, can_be_released in self.bucket_can_be_released.items():
if can_be_released:
self.release_bucket(bucket_id)
self.bucket_can_be_released[bucket_id] = False
@torch.no_grad()
def async_bucket_gather(self, bucket_id: int) -> None:
"""All-gather the bucket and set the items."""
self.bucket_can_be_released[bucket_id] = False
if self.bucket_status[bucket_id] != BucketStatus.EMPTY:
return
self.bucket_status[bucket_id] = BucketStatus.COMMUNICATING
wbuf = self.buffer.parameter_groups[bucket_id].model_weight_buffer
# Lazy release the unused buckets.
self.recycle_unused_buckets()
bucket = wbuf.fetch_bucket(and_allocate_params_data=True)
param_gather_event = torch.distributed.all_gather_into_tensor(
output_tensor=bucket.data,
input_tensor=wbuf.get_shard_from_local_buffer(),
group=wbuf.data_parallel_group,
async_op=True,
)
def get_closure(bucket_id):
@torch.no_grad()
def mark_bucket_ready_to_use():
self.bucket_status[bucket_id] = BucketStatus.READY_TO_USE
return mark_bucket_ready_to_use
mark_bucket_ready_to_use = get_closure(bucket_id)
self.param_gather_event_map[bucket_id] = (param_gather_event, mark_bucket_ready_to_use)
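# Illustrative usage sketch for AllGatherPipeline (hypothetical driver code):
#   ag_pipeline = AllGatherPipeline(param_and_grad_buffer)
#   params = list(layer.parameters())  # `layer` is a hypothetical FSDP unit
#   ag_pipeline.all_gather_params(params, prefetch=True)
#   for p in params:
#       ag_pipeline.wait_bucket_ready(ag_pipeline.buffer.param_to_param_group[p])
#   ...  # run forward/backward with the unsharded parameters
#   ag_pipeline.recycle_unused_buckets()  # frees buckets previously marked releasable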
@torch.no_grad()
def gradient_reduce_preprocessing(grad_data, scaling_factor, ddp_config):
"""
Gradient reduce preprocessing for gradient averaging and gradient scaling.
"""
if scaling_factor is None:
reduce_op = torch.distributed.ReduceOp.SUM
elif ddp_config.average_in_collective:
reduce_op = torch.distributed.ReduceOp.AVG
elif ddp_config.gradient_reduce_div_fusion and grad_data.dtype != torch.bfloat16:
reduce_op = torch.distributed._make_nccl_premul_sum(scaling_factor)
else:
grad_data.mul_(scaling_factor)
reduce_op = torch.distributed.ReduceOp.SUM
return reduce_op
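# Illustrative summary of the branches above (a sketch, not exhaustive):
#   scaling_factor is None                         -> ReduceOp.SUM, no scaling
#   ddp_config.average_in_collective               -> ReduceOp.AVG
#   gradient_reduce_div_fusion and non-bf16 grads  -> premul-sum (scale fused into NCCL)
#   otherwise                                      -> grads scaled in place, then ReduceOp.SUM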
def check_gpu_memory(threshold=0.9):
"""
Check if the GPU memory is over the threshold.
Args:
threshold (float, optional): The threshold to check if the GPU memory is over.
Defaults to 0.9.
Returns:
bool: True if the GPU memory is over the threshold.
"""
if not torch.cuda.is_available():
return False
device = torch.cuda.current_device()
allocated = torch.cuda.memory_allocated(device)
reserved = torch.cuda.memory_reserved(device)
total = torch.cuda.get_device_properties(device).total_memory
allocated_ratio = allocated / total
reserved_ratio = reserved / total
near_full = allocated_ratio >= threshold or reserved_ratio >= threshold
if near_full:
log_on_each_pipeline_stage(
logger,
logging.INFO,
f"GPU Memory: Allocated: {allocated_ratio:.2%}, Reserved: {reserved_ratio:.2%}",
)
return near_full
class ResetParametersContext:
"""
Context manager for resetting parameters for meta device initialization module.
"""
def __init__(self, init_param_with_fp8=False, with_cuda_rng_tracker=False):
self.init_param_with_fp8 = init_param_with_fp8
self.with_cuda_rng_tracker = with_cuda_rng_tracker
def __enter__(self):
self.stack = ExitStack()
if self.init_param_with_fp8:
args = {"enabled": True}
if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters:
args["preserve_high_precision_init_val"] = True
self.stack.enter_context(fp8_model_init(**args))
if self.with_cuda_rng_tracker:
self.stack.enter_context(get_cuda_rng_tracker().fork())
return self
def __exit__(self, *exc_details):
self.stack.__exit__(*exc_details)
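# Illustrative usage sketch (assumes `m` is a meta-device-initialized submodule
# that defines reset_parameters(), as in the buffer initialization above):
#   with ResetParametersContext(init_param_with_fp8=True, with_cuda_rng_tracker=True):
#       m.reset_parameters()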
def override_sharded_param_methods_with_safety_checks(params, all_gather_pipeline):
"""
Override the methods of the parameters to prevent undefined behavior.
Args:
params (List[torch.Tensor]): The sharded parameters whose methods are overridden with safety checks.
all_gather_pipeline (AllGatherPipeline): The all-gather pipeline.
"""
for p in params:
to_function = p.to
cpu_function = p.cpu
def override_sharded_param_to_function_closure(p, to_function):
def override_sharded_param_to_function(*args, **kwargs):
bucket_id = all_gather_pipeline.buffer.param_to_param_group[p]
status = all_gather_pipeline.bucket_status[bucket_id]
if status == BucketStatus.READY_TO_USE:
return to_function(*args, **kwargs)
raise RuntimeError(
"This parameter is already shard by MCore FSDP and the "
"shared-state parameter does not support 'to' function."
"please define the dtype and device of the parameter before FSDP wrap."
)
return override_sharded_param_to_function
setattr(p, 'to', override_sharded_param_to_function_closure(p, to_function))
def override_sharded_param_cpu_function_closure(p, cpu_function):
def override_sharded_param_cpu_function(*args, **kwargs):
bucket_id = all_gather_pipeline.buffer.param_to_param_group[p]
status = all_gather_pipeline.bucket_status[bucket_id]
if status == BucketStatus.READY_TO_USE:
return cpu_function(*args, **kwargs)
warnings.warn(
"The parameters are sharded by MCore FSDP, and no actual "
"cpu operation is performed."
)
return torch.empty([], device='cpu')
return override_sharded_param_cpu_function
setattr(p, 'cpu', override_sharded_param_cpu_function_closure(p, cpu_function))
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from contextlib import contextmanager
import torch
from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig
class _BaseDataParallel(MegatronModule):
"""A template class for DistributedDataParallel implementations."""
def __init__(self, config: TransformerConfig, module: torch.nn.Module):
super().__init__(config=config)
self.module = module
def forward(self, *inputs, **kwargs):
"""
Calls the wrapped module's forward() method.
"""
return self.module(*inputs, **kwargs)
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
try:
yield
finally:
pass
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
pass
def scale_gradients(self, scaling_factor: float) -> None:
"""Scale all gradients inside the buffers by `scaling_factor`."""
pass
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
pass
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
pass
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
pass
def state_dict(self, prefix='', keep_vars=False, destination=None):
"""
Returns a dictionary containing references to the whole state of the
wrapped module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names. Parameters and buffers
set to None are not included.
"""
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars, destination=destination)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""
Returns wrapped module's state_dict for checkpoint saving.
"""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
self.module.load_state_dict(state_dict, strict=strict)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from contextlib import contextmanager
from typing import Optional
import torch
from .. import parallel_state
from ..config_logger import has_config_logger_enabled, log_config_to_disk
from ..fp8_utils import is_float8tensor
from ..process_groups_config import GradCommProcessGroups, ModelCommProcessGroups
from ..transformer.cuda_graphs import is_graph_capturing
from ..transformer.transformer_config import TransformerConfig
from ..utils import log_single_rank
from .data_parallel_base import _BaseDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets
logger = logging.getLogger(__name__)
class DistributedDataParallel(_BaseDataParallel):
"""
DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping
communication with backprop computation by breaking up full model's gradients into smaller
buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class
also provides the option to do the gradient accumulation in a type other than the param type
(e.g., fp32 for a bf16 model).
Args:
config: Transformer config object.
ddp_config: DistributedDataParallel config object.
module: Underlying model.
disable_bucketing: If true, force assign all parameters to a single bucket. If false,
use standard bucketing policy: assign parameters to smaller buckets and all-reduce
per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
grad_comm_pgs: Optional gradient communication process groups.
model_comm_pgs: Optional model parallel communication process groups.
"""
def __init__(
self,
config: TransformerConfig,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
disable_bucketing: bool = False,
grad_comm_pgs: Optional[GradCommProcessGroups] = None,
model_comm_pgs: Optional[ModelCommProcessGroups] = None,
):
super().__init__(config=config, module=module)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.module = module
# If bucket_size is not provided as an input, use sane default.
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
# latency-bound.
if ddp_config.bucket_size is None:
ddp_config.bucket_size = max(
40000000, 1000000 * parallel_state.get_data_parallel_world_size()
)
# Set bucket_size to infinity if overlap_grad_reduce is False.
if not ddp_config.overlap_grad_reduce:
ddp_config.bucket_size = None
self.ddp_config = ddp_config
log_single_rank(
logger,
logging.INFO,
f'Setting up DistributedDataParallel with config {self.ddp_config}',
)
if grad_comm_pgs is None and model_comm_pgs is None:
self.dp_group = parallel_state.get_data_parallel_group(
with_context_parallel=False, partial_data_parallel=False
)
self.dp_cp_group = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=False
)
self.intra_dp_cp_group = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=True
)
self.expt_dp_group = parallel_state.get_expert_data_parallel_group()
self.intra_expt_dp_group = parallel_state.get_expert_data_parallel_group(
partial_expert_data_parallel=True
)
if self.ddp_config.num_distributed_optimizer_instances > 1:
self.inter_dist_opt_group = (
parallel_state.get_inter_distributed_optimizer_instance_group()
)
self.pp_group = parallel_state.get_pipeline_model_parallel_group()
self.ep_group = parallel_state.get_expert_model_parallel_group()
elif grad_comm_pgs is not None and model_comm_pgs is not None:
# 1. dp group - this is always required
if not hasattr(grad_comm_pgs, 'dp'):
raise ValueError("dp process group is required but not provided in grad_comm_pgs")
self.dp_group = grad_comm_pgs.dp
# 2. dp_cp group:
# - If provided in grad_comm_pgs, use it
# - Otherwise check context_parallel_size
# - If cp_size is 1, use same as dp
# - If cp_size > 1, raise error as dp_cp is needed
if hasattr(grad_comm_pgs, 'dp_cp'):
self.dp_cp_group = grad_comm_pgs.dp_cp
else:
cp_size = getattr(config, 'context_parallel_size', 1)
if cp_size == 1:
# If no context parallelism, dp_cp is same as dp
self.dp_cp_group = self.dp_group
else:
raise ValueError(
"dp_cp process group is required when context_parallel_size > 1 "
"but not provided in grad_comm_pgs"
)
# 3. Handle expert data parallel group
if hasattr(grad_comm_pgs, 'expt_dp'):
self.expt_dp_group = grad_comm_pgs.expt_dp
else:
# Create a new group with just the current rank
log_single_rank(
logger,
logging.WARNING,
"No expert data parallel group provided in grad_comm_pgs, "
"creating a new one with just the current rank",
)
# Ideally we don't want any expt_dp_group when expert data parallelism
# is not used, but downstream code expects one; it is used to check the
# group size and to calculate the gradient scaling factor.
self.expt_dp_group = torch.distributed.new_group(
ranks=[torch.distributed.get_rank()]
)
# 4. Handle intra_dp_cp, intra_expt_dp, and inter_dist_opt
# based on optimizer instances:
if self.ddp_config.num_distributed_optimizer_instances == 1:
# With a single optimizer instance:
# - intra_dp_cp is same as dp_cp
# - intra_expt_dp is same as expt_dp
# - inter_dist_opt is not needed
self.intra_dp_cp_group = self.dp_cp_group
self.intra_expt_dp_group = self.expt_dp_group
else:
# With multiple optimizer instances, both groups must be provided
if not (
hasattr(grad_comm_pgs, 'intra_dp_cp')
and hasattr(grad_comm_pgs, 'intra_expt_dp')
and hasattr(grad_comm_pgs, 'inter_dist_opt')
):
raise ValueError(
"intra_dp_cp, intra_expt_dp, and inter_dist_opt "
"process groups are required when using multiple optimizer "
"instances (>1) but not provided in grad_comm_pgs"
)
self.intra_dp_cp_group = grad_comm_pgs.intra_dp_cp
self.intra_expt_dp_group = grad_comm_pgs.intra_expt_dp
self.inter_dist_opt_group = grad_comm_pgs.inter_dist_opt
# 5. pp and ep group
if not all([hasattr(model_comm_pgs, 'pp'), hasattr(model_comm_pgs, 'ep')]):
raise ValueError(
"pp and ep process groups are required but not provided in model_comm_pgs"
)
self.pp_group = model_comm_pgs.pp
self.ep_group = model_comm_pgs.ep
else:
raise ValueError(
"Grad and model comm process groups must be provided or both must be None"
)
# Turn off bucketing if we are on a pipeline stage that is not the first (since
# data-parallel communication on these stages is not on the critical path), or if
# disable_bucketing is True (e.g., we might not want to break up model parameters
# into buckets for model chunks after the first in the interleaved schedule).
self.bucket_size = self.ddp_config.bucket_size
if isinstance(self.pp_group, list):
pp_rank = self.pp_group[0].rank()
else:
pp_rank = self.pp_group.rank()
if pp_rank > 0:
self.bucket_size = None
if disable_bucketing:
self.bucket_size = None
self.param_to_bucket_group = {}
# Group parameters by their gradient type.
param_to_name = {}
dense_params = []
expert_parallel_params = []
self.params_with_grad = []
for name, param in self.module.named_parameters():
if not param.requires_grad:
continue
# Track params with grad to enable direct setting
# of param.grad_added_to_main_grad
self.params_with_grad.append(param)
param.grad_added_to_main_grad = False
param_to_name[param] = name
if getattr(param, 'allreduce', True):
dense_params.append(param)
else:
expert_parallel_params.append(param)
def _allocate_buffers_for_parameters(
input_params, data_parallel_group, gradient_scaling_factor
):
param_and_grad_dtype_to_params = {}
param_and_grad_dtype_to_offsets = {}
param_and_grad_dtype_to_indices = {}
# Group parameters by their gradient type.
for param in input_params:
assert param.requires_grad
param_dtype = param.dtype
if is_float8tensor(param):
# Currently TE's Float8Tensor is a wrapper of torch.Tensor. It has a "fake"
# dtype (usually a higher precision dtype such as bfloat16), but its actual
# data is stored in the form of a torch uint8 tensor within the Float8Tensor's
# ".data" attribute. Therefore, when creating the param buffer for fp8 params,
# it is necessary to use torch.uint8, not the "fake" dtype got from
# "param.dtype".
param_dtype = torch.uint8
grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype
params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
params.append(param)
param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params
# Get the index of each param among the params with same dtype, if a param is fp8,
# use its "fake" high precision dtype to find which params have same dtype with it.
# For example:
# Case 1:
# params = [p1(bf16), p2(bf16), p3(bf16), p4(bf16)]
# param_and_grad_dtype_to_indices = {
# (torch.bfloat16, torch.float32): [0, 1, 2, 3],
# }
# Case 2:
# params = [p1(bf16), p2(fp8), p3(fp8), p4(bf16)]
# param_and_grad_dtype_to_indices = {
# (torch.bfloat16, torch.float32): [0, 3],
# (torch.uint8, torch.float32): [1, 2],
# }
# We need these indices to load a non-native-fp8 checkpoint in native-fp8 mode.
offset = param_and_grad_dtype_to_offsets.get((param.dtype, grad_dtype), 0)
param_and_grad_dtype_to_offsets[(param.dtype, grad_dtype)] = offset + 1
indices = param_and_grad_dtype_to_indices.get((param_dtype, grad_dtype), [])
indices.append(offset)
param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] = indices
if not config.calculate_per_token_loss:
target_gradient_scaling_factor = 1.0 / self.dp_cp_group.size()
if self.ddp_config.average_in_collective:
if self.ddp_config.num_distributed_optimizer_instances == 1:
# Collective is averaging gradients in collective with data_parallel_group.
assert (
gradient_scaling_factor / data_parallel_group.size()
== target_gradient_scaling_factor
)
else:
# For non-expert parameters, gradient_scaling_factor is 1.
# For expert parameters, gradient_scaling_factor is edp_size/dp_size.
assert (gradient_scaling_factor == 1) or (
gradient_scaling_factor
== (self.expt_dp_group.size() / self.dp_cp_group.size())
)
else:
assert gradient_scaling_factor == target_gradient_scaling_factor
# Allocate the grad buffers and map the grads.
buffers = []
for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
buffers.append(
_ParamAndGradBuffer(
self.ddp_config,
param_dtype,
grad_dtype,
params,
data_parallel_group,
self.bucket_size,
param_to_name,
gradient_scaling_factor,
param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)],
self.ddp_config.nccl_ub,
)
)
# In some scenarios, we want to put buckets from different buffers into a group so that
# their communication can be aggregated. For example, when there are both fp8 buffers
# and bf16 buffers in the model and vpp is enabled, each model chunk will have an fp8
# bucket and a bf16 bucket, which doubles the number of communication kernels, and
# because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back
# communications will prevent the overlap of the communication kernels with computation
# kernels.
# If bucketing is explicitly disabled, then put all buckets in a buffer into a single
# bucket group.
bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing)
if self.ddp_config.num_distributed_optimizer_instances > 1:
assert (
self.ddp_config.use_distributed_optimizer
), 'Partial DistOpt cannot be used without DistOpt'
communication_stream = torch.cuda.Stream(device=torch.cuda.current_device())
for bucket_group in bucket_groups:
bucket_group.inter_distributed_optimizer_instance_group = (
self.inter_dist_opt_group
)
bucket_group.communication_stream = communication_stream
# Set `next_param_gather_bucket_group` for different bucket groups by iterating through
# buckets in reverse order (since all-gathers happen in reverse order of buckets).
if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
num_bucket_groups = len(bucket_groups)
for i in range(1, num_bucket_groups):
bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
bucket_groups[num_bucket_groups - i - 1]
)
# Create map from param to bucket group, used in pre_hook.
for bucket_group in bucket_groups:
for bucket in bucket_group.buckets:
for param in bucket.params_list:
self.param_to_bucket_group[param] = bucket_group
return buffers, bucket_groups
if config.calculate_per_token_loss:
assert (
not self.ddp_config.average_in_collective
), "Cannot average in collective when calculating per-token loss!"
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = 1.0
else:
# The goal is to scale reduced gradients by 1/dp_size.
# This can be achieved in two ways:
#
# Case 1: average_in_collective=True
# - Non-expert parameters:
# 1. No pre-scaling (gradient_scaling_factor=1.0)
# 2. Do average reduction over dp group (equivalent to sum then divide by dp_size)
# 3. Final result is scaled by 1/dp_size as desired
#
# - Expert parameters:
# 1. Scale by edp_size/dp_size before reduction
# 2. Do average reduction over edp group (equivalent to sum then divide by edp_size)
# 3. Resulting scaling: (edp_size/dp_size) * (1/edp_size) = 1/dp_size as desired
# (edp_size = expert data parallel world size)
#
# Case 2: average_in_collective=False
# - Both expert and non-expert parameters:
# 1. Scale gradients by 1/dp_size before reduction
# 2. Do sum reduction across data parallel ranks
# 3. Final result is scaled by 1/dp_size as desired
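# Worked example (illustrative numbers): with dp_size (incl. CP) = 8 and
# edp_size = 2, Case 1 pre-scales expert grads by 2/8 = 0.25 and the AVG
# collective divides by 2, giving 0.25 * 1/2 = 1/8 = 1/dp_size; in Case 2
# both expert and non-expert grads are simply pre-scaled by 1/8.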
if self.ddp_config.average_in_collective:
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = self.expt_dp_group.size() / self.dp_cp_group.size()
else:
data_parallel_world_size = self.dp_cp_group.size()
gradient_scaling_factor = 1.0 / data_parallel_world_size
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
# Allocate the param+grad buffers for dense params' grads.
self.buffers, self.bucket_groups = _allocate_buffers_for_parameters(
dense_params, self.intra_dp_cp_group, gradient_scaling_factor=gradient_scaling_factor
)
# Allocate separate param+grad buffers for expert parallel params' grads.
self.expert_parallel_buffers, self.expert_parallel_bucket_groups = (
_allocate_buffers_for_parameters(
expert_parallel_params,
self.intra_expt_dp_group,
gradient_scaling_factor=expert_gradient_scaling_factor,
)
)
# Delete references to weight_tensor if they exist since we don't want two parameter copies
# if we re-mapped parameters (which happens when we use the distributed optimizer).
# This is a temporary workaround around a TE bug that is fixed with
# https://github.com/NVIDIA/TransformerEngine/pull/719.
if self.ddp_config.use_distributed_optimizer:
@torch.no_grad()
def unmap_weight_tensor(m):
if hasattr(m, 'weight_tensor'):
m.weight_tensor = None
self.module.apply(unmap_weight_tensor)
# Register backward hook.
# Accumulation function for the gradients need to be stored so they
# don't go out of scope.
self.grad_accs = []
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
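# `param_tmp.grad_fn` is the ExpandBackward node of the view; its first
# `next_functions` entry is the AccumulateGrad node for `param`, which runs
# after the gradient has been accumulated into `param.grad`, so the hook
# registered below fires once per backward pass with `param.grad` populated.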
grad_acc.register_hook(self._make_backward_post_hook(param))
self.grad_accs.append(grad_acc)
self.use_forward_hook = (
self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather
)
self.remove_forward_pre_hook_handles = {}
if self.use_forward_hook:
self.enable_forward_pre_hook()
self.overlap_param_gather_with_optimizer_step = False
def enable_forward_pre_hook(self):
"""
Enable forward pre-hooks needed for param all-gather overlap with forward compute.
"""
assert self.use_forward_hook
assert len(self.remove_forward_pre_hook_handles) == 0
# Register forward pre-hook for all sub-modules.
for module in self.module.modules():
self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook(
self._make_forward_pre_hook()
)
def disable_forward_pre_hook(self, param_sync: bool = True):
"""
Disable forward pre-hooks needed for param all-gather overlap with forward compute.
Skip synchronous param all-gather if `param_sync` is False.
"""
assert self.use_forward_hook
# De-register forward pre-hook for all sub-modules.
for module in self.module.modules():
assert self.remove_forward_pre_hook_handles[module] is not None
self.remove_forward_pre_hook_handles[module].remove()
del self.remove_forward_pre_hook_handles[module]
assert len(self.remove_forward_pre_hook_handles) == 0
# Force synchronize parameters.
if param_sync:
self.start_param_sync(force_sync=True)
def _make_forward_pre_hook(self):
"""
Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
when a module uses a parameter in a bucket with a still incomplete all-gather).
"""
def hook(module, *unused):
assert (
self.use_forward_hook
), "Should use pre-hook only when overlap_param_gather is True"
if is_graph_capturing():
return
# Make sure all parameters in this module have been all-gathered as necessary.
for param in module.parameters(recurse=False):
# Skip parameters without an associated buffer (such parameters have a
# .requires_grad field equal to False).
if param not in self.param_to_bucket_group:
continue
assert param.requires_grad
# If aligning param all-gather across pipeline stages, all-gather is dispatched
# by start_param_sync calls in core/pipeline_parallelism/schedules.py.
# If overlapping param all-gather with optimizer step, then all-gather has
# already been dispatched in optimizer step.
skip_next_bucket_dispatch = (
self.ddp_config.align_param_gather
or self.overlap_param_gather_with_optimizer_step
)
self.param_to_bucket_group[param].finish_param_sync(
skip_next_bucket_dispatch=skip_next_bucket_dispatch
)
return hook
def _make_backward_post_hook(self, param: torch.nn.Parameter):
"""
Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
ready (i.e., when all grads in a bucket have been computed in all microbatches
in a batch).
"""
def hook(*unused):
if is_graph_capturing():
return
if param in self.param_to_bucket_group:
assert param.requires_grad
if self.ddp_config.overlap_grad_reduce:
assert (
param.grad is not None
), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
param.main_grad.add_(param.grad.data)
param.grad = None
if self.ddp_config.overlap_grad_reduce:
self.param_to_bucket_group[param].register_grad_ready(param)
return hook
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.is_last_microbatch = False
try:
yield
finally:
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.is_last_microbatch = True
def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
"""
Initiates param sync (all-gather) communication operations for all model parameters.
By default, when overlap_param_gather is set to True, dispatches asynchronous communication
calls; when overlap_param_gather is set to False, calls synchronous communication
ops. Can override this default behavior using flags below.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings.
force_dispatch (bool, optional): force dispatch regardless of other settings.
"""
if not force_sync:
# If overlapping param AG with optimizer step, AG should not be dispatched again
# in forward_backward_step.
if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
return
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.start_param_sync(force_sync=force_sync)
# For MXFP8 params, we need to copy the all-gathered param data from the buffer to
# param.data, since the param buffer is not mapped to the model params in the MXFP8 case.
# The parameters are cast from bf16 to MXFP8 during the copy.
if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
assert (
not self.ddp_config.overlap_param_gather
), "MXFP8 param currently does not support DP AG overlap."
for bucket in bucket_group.buckets:
for param in bucket.params:
param_start, param_end = bucket.param_to_index[param]
param_slice = bucket.param_data.view(-1)[param_start:param_end]
param.data.copy_(param_slice.view(param.data.shape))
# All-gathered params are not needed after being copied to param.data.
# Zero out the grad buffer (shared with param buffer) for gradient accumulation.
bucket.grad_data.zero_()
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.start_grad_sync()
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.finish_grad_sync()
def scale_gradients(self, scaling_factor: float):
"""Scale all gradients inside the buffers by `scaling_factor`."""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.scale_gradients(scaling_factor)
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
if not getattr(self.config, 'external_cuda_graph', False):
# Don't reset grad_added_to_main_grad when CUDA Graphs are used: inside a
# captured graph there is no opportunity to set it back to True, which would
# lead to the gradient being accumulated twice.
for param in self.params_with_grad:
param.grad_added_to_main_grad = False
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.reset()
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.reset()
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
for param in self.module.parameters():
is_expert_parallel = not getattr(param, 'allreduce', True)
if is_expert_parallel:
data_parallel_group = self.expt_dp_group
else:
data_parallel_group = self.dp_cp_group
torch.distributed.broadcast(
param.data,
src=torch.distributed.get_global_rank(data_parallel_group, 0),
group=data_parallel_group,
)
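A sketch of the typical gradient-accumulation pattern with this wrapper (illustrative only: `ddp_model`, `microbatches`, and `compute_loss` are assumed names, not defined in this file):
```
# Suppress gradient reduction for all but the last microbatch, then let the
# final backward pass (outside no_sync) follow the normal reduction path.
with ddp_model.no_sync():
    for batch in microbatches[:-1]:
        compute_loss(ddp_model(batch)).backward()
compute_loss(ddp_model(microbatches[-1])).backward()
ddp_model.finish_grad_sync()  # issue or wait for the bucketed grad reduction
```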
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Optional
@dataclass
class DistributedDataParallelConfig:
"""Configuration for DistributedDataParallel."""
grad_reduce_in_fp32: bool = False
"""If true, reduce grads in fp32."""
overlap_grad_reduce: bool = False
"""If true, overlap grad all-reduce / reduce-scatter with backward compute."""
overlap_param_gather: bool = False
"""If true, overlap param all-gather with forward compute."""
align_param_gather: bool = False
"""If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each
PP stage will independently launch as needed.
"""
use_distributed_optimizer: bool = False
"""If true, issue reduce-scatter collectives to aggregate gradients and clean up
originally allocated model parameters, otherwise issue all-reduce collectives.
"""
num_distributed_optimizer_instances: int = 1
"""Sets the factor by which the DP domain is sharded to have the partial DistOpt
enabled. Defaults to 1, which means DistOpt is across entire DP domain.
"""
check_for_nan_in_grad: bool = False
"""If true, check for NaNs and Infs in gradients _before_ communication collective."""
check_for_large_grads: bool = False
"""If true, check for unexpectedly large gradients _before_ communication collective."""
bucket_size: Optional[int] = None
"""Maximum number of parameters in each bucket. If unspecified, MCore uses a default
value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
buckets to ensure collectives do not become latency-bound)."""
pad_buckets_for_high_nccl_busbw: bool = False
"""If true, make sure the bucket size is divisible by a large power of 2 (2^16) to
ensure NCCL collectives have high bus bandwidth at large DP counts, since NCCL
message size (which for ring algorithms is bucket_size / dp_size) apparently needs
to be divisible by a power of 2 for high busbw."""
average_in_collective: bool = False
"""If true, compute average in collective directly, as opposed to dividing by the
dp_size first and then computing sum in the collective."""
fp8_param_gather: bool = False
"""If true, keep the compute param in fp8 (do not use any other intermediate dtype) and
perform the param all-gather in fp8."""
reuse_grad_buf_for_mxfp8_param_ag: bool = False
"""If true, reuse the grad buffer for param AG when using mxfp8 recipe. Should be
set to True only when fp8_recipe is mxfp8 and fp8_param_gather is True."""
use_custom_fsdp: bool = False
"""If true, use the FSDP code path for DDP."""
data_parallel_sharding_strategy: str = 'no_shard'
"""Sharding strategy for FSDP. Valid values are 'no_shard', 'optim',
'optim_grads', 'optim_grads_params'."""
gradient_reduce_div_fusion: bool = True
"""If true, perform gradient reduce and division fusion."""
suggested_communication_unit_size: Optional[int] = None
"""Specifies the number of elements to communicate at once during
FSDP (Fully Sharded Data Parallel) operations.
This flag also affects FSDP all-gather prefetch behavior. Setting a larger
value increases the communication buffer size, while a smaller value
disables prefetching and may degrade performance. Adjust this value
based on your system's memory and performance requirements."""
preserve_fp32_weights: bool = True
"""If true, preserve fp32 weights in the custom FSDP ParamAndGradBuffer."""
keep_fp8_transpose_cache_when_using_custom_fsdp: bool = False
"""If true, keep the fp8 transpose cache when using custom FSDP."""
nccl_ub: bool = False
"""If true, allocate and register NCCL userbuffer for param and grad buffer.
This flag enables SM efficient nccl algorithm that could improve the performance
of FSDP and DP with comm_overlap. This flag will be much more effective when used
together with sharp.
The following table lists the expected SM usage for various cases.
(Note that these are reference numbers only; actual SM usage can vary with
message size, communication domain size, and NCCL version.)
----------------------------------------------------------
| Communication domain | use_sharp | SM usage of "AG/RS" |
|----------------------|-----------|---------------------|
| NVL | N/A | 4 / 5 |
| NVL+IB | False | 16 / 16 |
| NVL+IB | True | 6 / 6 |
| IB | False | 1 / 4 |
| IB | True | 1 / 1 |
----------------------------------------------------------
"""
fsdp_double_buffer: bool = False
"""If true, use persistently allocated double buffers for the
temporary memory needed in the custom FSDP communications.
This option incurs additional memory overhead; however, it is required in order
to register the NCCL user buffer (nccl_ub=True) for the custom FSDP.
This option is automatically set to True when nccl_ub=True.
"""
def __post_init__(self):
"""Check the validity of the config."""
if self.reuse_grad_buf_for_mxfp8_param_ag:
assert self.fp8_param_gather, "Reuse grad buffer only when keeping params in MXFP8."
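A minimal configuration sketch (the specific values are illustrative, not recommended defaults):
```
ddp_config = DistributedDataParallelConfig(
    grad_reduce_in_fp32=True,
    overlap_grad_reduce=True,
    overlap_param_gather=True,
    use_distributed_optimizer=True,
    bucket_size=40_000_000,  # omit to let MCore pick max(40e6, 1e6 * dp_size)
)
```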
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, Optional, Union
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
try:
from torch.distributed._tensor import DTensor, distribute_tensor
HAVE_DTENSOR = True
except ImportError:
HAVE_DTENSOR = False
from .. import parallel_state
from ..transformer.moe.moe_utils import get_updated_expert_bias
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config
def _get_main_grad_attr(param: torch.nn.Parameter, use_custom_fsdp: bool = False):
if use_custom_fsdp:
return "fsdp_managed_main_grad"
if hasattr(param, "main_grad"):
return "main_grad"
return "grad"
def _unshard_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor:
"""
Unshards the input tensor if it is a DTensor and otherwise returns the
tensor unmodified.
Args:
tensor (Union[torch.Tensor, DTensor]): The tensor to potentially unshard.
Returns:
An unsharded version of the input tensor if it is a DTensor, or the
input tensor unmodified if it is not a DTensor.
"""
if HAVE_DTENSOR and isinstance(tensor, DTensor):
unsharded_tensor = tensor.full_tensor()
for k, v in vars(tensor).items():
setattr(unsharded_tensor, k, v)
return unsharded_tensor
return tensor
def _reshard_if_dtensor(
tensor_to_shard: torch.Tensor, reference_tensor: Union[torch.Tensor, "DTensor"]
) -> Union[torch.Tensor, "DTensor"]:
"""
Reshards the input tensor to match the sharding configuration of the
reference tensor if the reference tensor is a DTensor. Otherwise, returns
the reference tensor unmodified.
Args:
tensor_to_shard (torch.Tensor): The tensor to be potentially sharded.
reference_tensor (Union[torch.Tensor, DTensor]): The reference tensor
for the sharding configuration.
Returns:
Union[torch.Tensor, DTensor]: The sharded tensor matching the reference tensor's
configuration, or the reference tensor itself if it is not a DTensor.
"""
if HAVE_DTENSOR and isinstance(reference_tensor, DTensor):
sharded_tensor = distribute_tensor(
tensor_to_shard,
device_mesh=reference_tensor.device_mesh,
placements=reference_tensor.placements,
)
for k, v in vars(reference_tensor).items():
setattr(sharded_tensor, k, v)
return sharded_tensor
return reference_tensor
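A standalone round-trip sketch for the two helpers above (a sketch only: it must be launched with `torchrun`, uses a CPU/gloo mesh for simplicity, and assumes the world size divides the sharded dimension):
```
import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

full = torch.arange(16, dtype=torch.float32).reshape(4, 4)
sharded = distribute_tensor(full, device_mesh=mesh, placements=[Shard(0)])

unsharded = _unshard_if_dtensor(sharded)              # plain torch.Tensor again
resharded = _reshard_if_dtensor(unsharded, sharded)   # back to a DTensor
assert torch.equal(resharded.full_tensor(), full)
```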
def _allreduce_conditional_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce conditional embedding grads.
Reduce grads across all the pp stages to ensure that parameters of the conditional embedders
(e.g., timestep embedder, FPS embedder, label embedder) stay in sync.
This is for the models with replicated embedders on each PP / VPP rank, like diffusion models.
"""
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and getattr(
config, "has_cond_embedder", False
):
grads_dict = {}
for model_chunk in model:
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
if param.requires_grad and getattr(param, 'pipeline_parallel', False):
grad = param.main_grad
if name in grads_dict:
# Add all the virtual PP rank's gradients to
# the first local virtual PP rank.
grads_dict[name][0].add_(grad)
# Append to the end for later update after cross-rank reduce.
grads_dict[name].append(grad)
else:
grads_dict[name] = [grad]
if grads_dict:
# All-reduce the gradient on the first VPP rank.
grads = [param_grad[0] for _, param_grad in grads_dict.items()]
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, group=parallel_state.get_pipeline_model_parallel_group()
)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
# Update the gradients on other VPP ranks.
for grads in grads_dict.values():
for grad in grads[1:]:
grad.copy_(grads[0])
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce word embedding grads.
Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
sync.
"""
if (
parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
and parallel_state.get_embedding_group().size() > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
ddp_config = model_module.ddp_config
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
# If share_embeddings_and_output_weights is True, we need to maintain duplicated
# embedding weights in post processing stage. If use Multi-Token Prediction (MTP),
# we also need to maintain duplicated embedding weights in mtp process stage.
# So we need to allreduce grads of embedding in the embedding group in these cases.
if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
weight = model_module.shared_embedding_or_output_weight()
grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
# When the embedding is frozen, the grad is None.
if grad is None:
return
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce position_embeddings grad across encoder and decoder stages to ensure that position
embeddings parameters stay in sync.
"""
if (
parallel_state.is_rank_in_position_embedding_group()
and parallel_state.get_position_embedding_group().size() > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
ddp_config = model_module.ddp_config
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
assert hasattr(model_module, 'position_embeddings')
weight = model_module.position_embeddings.weight
grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce both word and position embeddings.
"""
_allreduce_word_embedding_grads(model, config)
_allreduce_position_embedding_grads(model, config)
def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig):
"""
Update the expert bias of the router for a global batch.
This requires all-reduce of local_tokens_per_expert across TPxCPxDP ranks
"""
tokens_per_expert_list = []
expert_bias_list = []
for model_chunk in model:
for module in get_attr_wrapped_model(model_chunk, 'modules')():
if hasattr(module, 'expert_bias'):
tokens_per_expert_list.append(module.local_tokens_per_expert)
expert_bias_list.append(module.expert_bias)
# For hybrid models with both MoE and Dense layers, this list can be empty.
if len(expert_bias_list) == 0:
return
stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0)
stacked_expert_bias = torch.stack(expert_bias_list, dim=0)
stacked_updated_expert_bias = get_updated_expert_bias(
stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate
)
for tokens_per_expert, expert_bias, updated_expert_bias in zip(
tokens_per_expert_list, expert_bias_list, stacked_updated_expert_bias
):
tokens_per_expert.zero_()
expert_bias.copy_(updated_expert_bias)
def _allreduce_non_tensor_model_parallel_grads(
model: List[torch.nn.Module], config: TransformerConfig
):
"""
All-reduce both layernorm grads (for sequence parallelism) and
gradients from modules with average_gradients_across_tp_domain=True
across tensor-model-parallel ranks.
"""
if parallel_state.get_tensor_model_parallel_world_size() <= 1:
return
params_sum = []
grads_sum = []
params_avg = []
grads_avg = []
for model_chunk in model:
ddp_config = model_chunk.ddp_config
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
if param.requires_grad:
# Check if this param needs average reduction (average_gradients_across_tp_domain)
if getattr(param, "average_gradients_across_tp_domain", False):
params_avg.append(param)
grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
grad = getattr(param, grad_attr)
grad = _unshard_if_dtensor(grad)
grads_avg.append(grad.data)
# Check if this param needs sum reduction (sequence parallel or qk_layernorm)
elif (config.sequence_parallel and getattr(param, "sequence_parallel", False)) or (
config.qk_layernorm and ("q_layernorm" in name or "k_layernorm" in name)
):
params_sum.append(param)
grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
grad = getattr(param, grad_attr)
grad = _unshard_if_dtensor(grad)
grads_sum.append(grad.data)
# Loop grads and perform correct all-reduce
for params, grads, all_reduce_op in zip(
[params_sum, params_avg],
[grads_sum, grads_avg],
[torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp.AVG],
):
if grads:
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, op=all_reduce_op, group=parallel_state.get_tensor_model_parallel_group()
)
for param, buf, synced in zip(
params, grads, _unflatten_dense_tensors(coalesced, grads)
):
buf.copy_(synced)
grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
orig_grad = getattr(param, grad_attr)
setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))
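A single-process sketch of the flatten / reduce / unflatten pattern used above, with the collective omitted (only the coalescing round-trip is shown; the helpers are the same `torch._utils` functions imported at the top of this file):
```
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.randn(3, 3), torch.randn(5)]
coalesced = _flatten_dense_tensors(grads)          # one contiguous 1-D buffer
# ... a single all_reduce on `coalesced` would go here ...
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)                              # copy results back per-tensor
```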
"""
This is an alias to _allreduce_non_tensor_model_parallel_grads that we must
maintain for legacy tests. We can remove this proxy in mcore 0.14.
"""
_allreduce_layernorm_grads = _allreduce_non_tensor_model_parallel_grads
def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
"""
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
embedding grads across first and last pipeline stages (if not tied),
scale gradients by `num_tokens`.
"""
config = get_model_config(model[0])
# All-reduce / reduce-scatter across DP replicas.
if config.timers is not None:
config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
for model_chunk in model:
model_chunk.finish_grad_sync()
if config.timers is not None:
config.timers('all-grads-sync').stop()
# All-reduce t_embedder grads (for pp & vpp of DiT).
if config.timers is not None:
config.timers('conditional-embedder-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_conditional_embedding_grads(model, config)
if config.timers is not None:
config.timers('conditional-embedder-grads-all-reduce').stop()
# All-reduce layer-norm grads (for sequence parallelism) and non-tensor parallel modules.
if config.timers is not None:
config.timers('non-tensor-parallel-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_non_tensor_model_parallel_grads(model, config)
if config.timers is not None:
config.timers('non-tensor-parallel-grads-all-reduce').stop()
# All-reduce embedding grads (for pipeline parallelism).
if config.timers is not None:
config.timers('embedding-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_embedding_grads(model, config)
if config.timers is not None:
config.timers('embedding-grads-all-reduce').stop()
if config.moe_router_enable_expert_bias:
_update_router_expert_bias(model, config)
# Normalize gradients for per-token loss normalization.
# If we are normalizing by the number of tokens, use that count as the divisor;
# it is the total number of non-padded tokens in the global batch.
if num_tokens is not None:
# the number of tokens is only present on the last stage, so broadcast it
# to the other ranks in the pipeline parallel group.
last_rank = parallel_state.get_pipeline_model_parallel_last_rank()
pp_group = parallel_state.get_pipeline_model_parallel_group()
if not isinstance(last_rank, list):
last_rank = [last_rank]
assert not isinstance(pp_group, list)
pp_group = [pp_group]
# need to do a broadcast for every pp group, even though num_tokens should be the same.
num_tokens_list = []
for lr, group in zip(last_rank, pp_group):
torch.distributed.broadcast(num_tokens, src=lr, group=group)
num_tokens_list.append(torch.clone(num_tokens))
assert all(x.item() == num_tokens_list[0] for x in num_tokens_list)
# all-reduce across DP ranks.
torch.distributed.all_reduce(
num_tokens, group=parallel_state.get_data_parallel_group(with_context_parallel=True)
)
for model_chunk in model:
if num_tokens > 0:
scaling = 1.0 / num_tokens
model_chunk.scale_gradients(scaling)
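A worked instance of the final scaling step (purely illustrative numbers):
```
num_tokens = 4096           # non-padded tokens in the global batch (example value)
scaling = 1.0 / num_tokens  # applied to every gradient in every model chunk
# e.g. a summed gradient value of 8.0 becomes 8.0 * scaling == 0.001953125
```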