Commit 160bf237 authored by wangxj

Update 0.12

parent b01809dd
Pipeline #2448 failed with stages
@@ -70,7 +70,7 @@ class _BaseDataParallel(MegatronModule):
         """
         pass

-    def state_dict(self, prefix='', keep_vars=False):
+    def state_dict(self, prefix='', keep_vars=False, destination=None):
         """
         Returns a dictionary containing references to the whole state of the
         wrapped module.
@@ -79,7 +79,7 @@ class _BaseDataParallel(MegatronModule):
         Keys are corresponding parameter and buffer names. Parameters and buffers
         set to None are not included.
         """
-        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
+        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars, destination=destination)

     def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
         """
......
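The `destination` forwarding above keeps the wrapper signature-compatible with `torch.nn.Module.state_dict`, so callers can collect results into a pre-allocated dict. A minimal sketch of the pattern; the `Wrapper` class below is illustrative, not part of this commit:

```python
from collections import OrderedDict

import torch


class Wrapper(torch.nn.Module):
    """Illustrative stand-in for a DDP-style wrapper that forwards state_dict kwargs."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def state_dict(self, prefix='', keep_vars=False, destination=None):
        # Forwarding destination (like the change above) lets callers pass a
        # pre-allocated dict that results are written into.
        return self.module.state_dict(
            prefix=prefix, keep_vars=keep_vars, destination=destination
        )


wrapped = Wrapper(torch.nn.Linear(4, 4))
shared = OrderedDict()
out = wrapped.state_dict(destination=shared, prefix='wrapped.')
assert out is shared and 'wrapped.weight' in shared
```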
@@ -7,8 +7,10 @@ import torch

 from .. import parallel_state
 from ..config_logger import has_config_logger_enabled, log_config_to_disk
+from ..fp8_utils import is_float8tensor
+from ..transformer.cuda_graphs import is_graph_capturing
 from ..transformer.transformer_config import TransformerConfig
-from ..utils import is_float8tensor, log_single_rank
+from ..utils import log_single_rank
 from .data_parallel_base import _BaseDataParallel
 from .distributed_data_parallel_config import DistributedDataParallelConfig
 from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets
@@ -151,12 +153,25 @@ class DistributedDataParallel(_BaseDataParallel):
                with_context_parallel=True
            )
            if self.ddp_config.average_in_collective:
-                # Collective is averaging gradients in collective with data_parallel_group.
-                assert (
-                    gradient_scaling_factor
-                    / parallel_state.get_data_parallel_world_size(with_context_parallel=True)
-                    == target_gradient_scaling_factor
-                )
+                if self.ddp_config.num_distributed_optimizer_instances == 1:
+                    # Collective is averaging gradients in collective with data_parallel_group.
+                    assert (
+                        gradient_scaling_factor
+                        / torch.distributed.get_world_size(group=data_parallel_group)
+                        == target_gradient_scaling_factor
+                    )
+                else:
+                    # For non-expert parameters, gradient_scaling_factor is 1.
+                    # For expert parameters, gradient_scaling_factor is edp_size/dp_size.
+                    assert (gradient_scaling_factor == 1) or (
+                        gradient_scaling_factor
+                        == (
+                            parallel_state.get_expert_data_parallel_world_size()
+                            / parallel_state.get_data_parallel_world_size(
+                                with_context_parallel=True
+                            )
+                        )
+                    )
            else:
                assert gradient_scaling_factor == target_gradient_scaling_factor
@@ -189,6 +204,9 @@ class DistributedDataParallel(_BaseDataParallel):
        bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing)

        if self.ddp_config.num_distributed_optimizer_instances > 1:
+            assert (
+                parallel_state.get_expert_model_parallel_world_size() == 1
+            ), "Partial DistOpt cannot support MoE models with expert parallelism."
            assert (
                self.ddp_config.use_distributed_optimizer
            ), 'Partial DistOpt cannot be used without DistOpt'
@@ -220,10 +238,31 @@ class DistributedDataParallel(_BaseDataParallel):
            gradient_scaling_factor = 1.0
            expert_gradient_scaling_factor = 1.0
        else:
+            # The goal is to scale reduced gradients by 1/dp_size.
+            # This can be achieved in two ways:
+            #
+            # Case 1: average_in_collective=True
+            # - Non-expert parameters:
+            #   1. No pre-scaling (gradient_scaling_factor=1.0)
+            #   2. Do average reduction over dp group (equals to sum then divide by dp_size)
+            #   3. Final result is scaled by 1/dp_size as desired
+            #
+            # - Expert parameters:
+            #   1. Scale by edp_size/dp_size before reduction
+            #   2. Do average reduction over edp group (equals to sum then divide by edp_size)
+            #   3. Resulted scaling: (edp_size/dp_size) * (1/edp_size) = 1/dp_size as desired
+            #      (edp_size = expert data parallel world size)
+            #
+            # Case 2: average_in_collective=False
+            # - Both expert and non-expert parameters:
+            #   1. Scale gradients by 1/dp_size before reduction
+            #   2. Do sum reduction across data parallel ranks
+            #   3. Final result is scaled by 1/dp_size as desired
            if self.ddp_config.average_in_collective:
                gradient_scaling_factor = 1.0
                expert_gradient_scaling_factor = (
-                    1.0 / parallel_state.get_expert_model_parallel_world_size()
+                    parallel_state.get_expert_data_parallel_world_size()
+                    / parallel_state.get_data_parallel_world_size(with_context_parallel=True)
                )
            else:
                data_parallel_world_size = parallel_state.get_data_parallel_world_size(
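The comment block in the hunk above can be checked with a quick worked example; the dp/edp sizes below are made up for illustration:

```python
# Worked example of the scaling described in the comment above (sizes are illustrative).
dp_size = 16   # data-parallel world size (with context parallel)
edp_size = 4   # expert data-parallel world size

# Case 1: average_in_collective=True
non_expert = 1.0 * (1.0 / dp_size)                # no pre-scaling, AVG over dp group
expert = (edp_size / dp_size) * (1.0 / edp_size)  # pre-scale, AVG over edp group

# Case 2: average_in_collective=False
both = (1.0 / dp_size) * 1.0                      # pre-scale by 1/dp_size, SUM reduction

assert non_expert == expert == both == 1.0 / dp_size
```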
@@ -297,9 +336,10 @@ class DistributedDataParallel(_BaseDataParallel):
                self._make_forward_pre_hook()
            )

-    def disable_forward_pre_hook(self):
+    def disable_forward_pre_hook(self, param_sync: bool = True):
        """
        Disable forward pre-hooks needed for param all-gather overlap with forward compute.
+        Skip synchronous param all-gather if `param_sync` is False.
        """
        assert self.use_forward_hook
        # De-register forward pre-hook for all sub-modules.
@@ -310,7 +350,8 @@ class DistributedDataParallel(_BaseDataParallel):
        assert len(self.remove_forward_pre_hook_handles) == 0

        # Force synchronize parameters.
-        self.start_param_sync(force_sync=True)
+        if param_sync:
+            self.start_param_sync(force_sync=True)

    def _make_forward_pre_hook(self):
        """
@@ -323,6 +364,9 @@ class DistributedDataParallel(_BaseDataParallel):
                self.use_forward_hook
            ), "Should use pre-hook only when overlap_param_gather is True"

+            if is_graph_capturing():
+                return
+
            # Make sure all parameters in this module have been all-gathered as necessary.
            for param in module.parameters(recurse=False):
                # Skip parameters without an associated buffer (such parameters have a
@@ -353,6 +397,9 @@ class DistributedDataParallel(_BaseDataParallel):
        """

        def hook(*unused):
+            if is_graph_capturing():
+                return
+
            if param in self.param_to_bucket_group:
                assert param.requires_grad
                if self.ddp_config.overlap_grad_reduce:
......
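The two `is_graph_capturing()` guards above make the param-gather hooks no-ops while a CUDA graph is being captured. A minimal sketch of that guard pattern, using `torch.cuda.is_current_stream_capturing()` as an illustrative stand-in for Megatron's helper (the hook body is hypothetical):

```python
import torch


def make_forward_pre_hook(dispatch_param_gather):
    """Illustrative forward pre-hook: skip side effects during CUDA graph capture."""

    def hook(module, args):
        # Mirrors the guard added above: launching bucket all-gathers from inside
        # a graph capture would bake the communication into the captured graph.
        if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
            return
        dispatch_param_gather(module)

    return hook


# Usage sketch: register on a sub-module, remove the handle to disable (optionally
# skipping the final sync, as disable_forward_pre_hook(param_sync=False) does above).
# handle = submodule.register_forward_pre_hook(make_forward_pre_hook(my_dispatch_fn))
# handle.remove()
```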
@@ -33,13 +33,22 @@ class DistributedDataParallelConfig:
    """

    check_for_nan_in_grad: bool = False
-    """ If true, check for NaNs in gradients _before_ communication collective."""
+    """If true, check for NaNs and Infs in gradients _before_ communication collective."""

+    check_for_large_grads: bool = False
+    """If true, check for unexpectedly large gradients _before_ communication collective."""
+
    bucket_size: Optional[int] = None
    """Maximum number of parameters in each bucket. If unspecified, MCore uses a default
    value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
    buckets to ensure collectives do not become latency-bound)."""

+    pad_buckets_for_high_nccl_busbw: bool = False
+    """If true, make sure the bucket size is divisible by a large power of 2 (2^16) to
+    ensure NCCL collectives have high bus bandwidth at large DP counts, since NCCL
+    message size (which for ring algorithms is bucket_size / dp_size) apparently needs
+    to be divisible by a power of 2 for high busbw."""
+
    average_in_collective: bool = False
    """If true, compute average in collective directly, as opposed to dividing by the
    dp_size first and then computing sum in the collective."""
@@ -47,3 +56,23 @@ class DistributedDataParallelConfig:

    fp8_param_gather: bool = False
    """If true, keep the compute param in fp8 (do not use any other intermediate dtype) and
    perform the param all-gather in fp8."""
+
+    use_custom_fsdp: bool = False
+    """If true, use the FSDP code path for DDP."""
+
+    data_parallel_sharding_strategy: str = 'no_shard'
+    """Sharding strategy for FSDP. Valid values are 'no_shard', 'optim',
+    'optim_grads', 'optim_grads_params'."""
+
+    gradient_reduce_div_fusion: bool = True
+    """If true, perform gradient reduce and division fusion."""
+
+    suggested_communication_unit_size: int = 400_000_000
+    """When batch communication is needed across multiple buckets,
+    this environment variable guides the size of communication unit size."""
+
+    preserve_fp32_weights: bool = True
+    """If true, preserve fp32 weights in the custom FSDP ParamAndGradBuffer."""
+
+    keep_fp8_transpose_cache_when_using_custom_fsdp: bool = False
+    """If true, keep the fp8 transpose cache when using custom FSDP."""
@@ -13,10 +13,19 @@ except ImportError:
    HAVE_DTENSOR = False

from .. import parallel_state
+from ..transformer.moe.moe_utils import get_updated_expert_bias
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config


+def _get_main_grad_attr(param: torch.nn.Parameter, use_custom_fsdp: bool = False):
+    if use_custom_fsdp:
+        return "fsdp_managed_main_grad"
+    if hasattr(param, "main_grad"):
+        return "main_grad"
+    return "grad"
+
+
def _unshard_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor:
    """
    Unshards the input tensor if it is a DTensor and otherwise returns the
@@ -126,10 +135,11 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    else:  # We do not support an interleaved schedule for models with encoders yet.
        model_module = model[0]

+    ddp_config = model_module.ddp_config
    model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
    if model_module.share_embeddings_and_output_weights:
        weight = model_module.shared_embedding_or_output_weight()
-        grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
+        grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
        orig_grad = getattr(weight, grad_attr)
        grad = _unshard_if_dtensor(orig_grad)
        torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
@@ -152,10 +162,11 @@ def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    else:  # We do not support an interleaved schedule for models with encoders yet.
        model_module = model[0]

+    ddp_config = model_module.ddp_config
    model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
    assert hasattr(model_module, 'position_embeddings')
    weight = model_module.position_embeddings.weight
-    grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
+    grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
    orig_grad = getattr(weight, grad_attr)
    grad = _unshard_if_dtensor(orig_grad)
    torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
@@ -184,14 +195,13 @@ def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
        grads = []
        for model_chunk in model:
            for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
-                if (
-                    param.requires_grad
-                    and getattr(param, 'sequence_parallel', False)
+                if param.requires_grad and (
+                    getattr(param, 'sequence_parallel', False)
                    or 'q_layernorm' in name
                    or 'k_layernorm' in name
                ):
                    params.append(param)
-                    grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad"
+                    grad_attr = _get_main_grad_attr(param, config.use_custom_fsdp)
                    grad = getattr(param, grad_attr)
                    grad = _unshard_if_dtensor(grad)
                    grads.append(grad.data)
@@ -204,11 +214,39 @@
                params, grads, _unflatten_dense_tensors(coalesced, grads)
            ):
                buf.copy_(synced)
-                grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad"
+                grad_attr = _get_main_grad_attr(param, config.use_custom_fsdp)
                orig_grad = getattr(param, grad_attr)
                setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))


+def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig):
+    """
+    Update the expert bias of the router for a global batch.
+    This requires all-reduce of local_tokens_per_expert across TPxCPxDP ranks
+    """
+    tokens_per_expert_list = []
+    expert_bias_list = []
+    for model_chunk in model:
+        for module in get_attr_wrapped_model(model_chunk, 'modules')():
+            if hasattr(module, 'expert_bias'):
+                tokens_per_expert_list.append(module.local_tokens_per_expert)
+                expert_bias_list.append(module.expert_bias)
+    # For hybrid models with both MoE and Dense layers, this list can be empty.
+    if len(expert_bias_list) == 0:
+        return
+    stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0)
+    stacked_expert_bias = torch.stack(expert_bias_list, dim=0)
+    stacked_updated_expert_bias = get_updated_expert_bias(
+        stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate
+    )
+    for tokens_per_expert, expert_bias, updated_expert_bias in zip(
+        tokens_per_expert_list, expert_bias_list, stacked_updated_expert_bias
+    ):
+        tokens_per_expert.zero_()
+        expert_bias.copy_(updated_expert_bias)
+
+
def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
    """
    All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
@@ -253,6 +291,9 @@ def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce').stop()

+    if config.moe_router_enable_expert_bias:
+        _update_router_expert_bias(model, config)
+
    # normalize gradients for per-token loss normalization.
    # if we are using by the number of tokens, then we use that as a divisor. this number
    # will be the total number of non-padded tokens in the global batch.
......
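For context on `_update_router_expert_bias` above: `get_updated_expert_bias` implements the aux-loss-free load-balancing update, which (as I understand it) nudges each expert's routing bias toward the mean load. A local sketch of that rule, with the cross-rank all-reduce of `tokens_per_expert` omitted and the function name chosen for illustration:

```python
import torch


def updated_expert_bias_sketch(tokens_per_expert, expert_bias, update_rate):
    # Under-loaded experts (fewer tokens than average) get their bias nudged up,
    # over-loaded experts get nudged down; the real helper also all-reduces
    # tokens_per_expert across TPxCPxDP ranks before this step.
    mean_load = tokens_per_expert.float().mean(dim=-1, keepdim=True)
    return expert_bias + update_rate * torch.sign(mean_load - tokens_per_expert)


bias = torch.zeros(4)
load = torch.tensor([10.0, 30.0, 25.0, 15.0])  # tokens routed to each of 4 experts
print(updated_expert_bias_sketch(load, bias, update_rate=1e-3))
# tensor([ 0.0010, -0.0010, -0.0010,  0.0010])
```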
@@ -2,8 +2,10 @@

import logging
import math
+import warnings
from contextlib import nullcontext
from enum import Enum
+from functools import partial
from typing import Dict, List, Optional

import torch
@@ -11,7 +13,8 @@ from torch.distributed import _coalescing_manager

from megatron.core.rerun_state_machine import get_rerun_state_machine

-from ..utils import is_float8tensor, is_torch_min_version, log_on_each_pipeline_stage
+from ..fp8_utils import is_float8tensor
+from ..utils import is_torch_min_version, log_on_each_pipeline_stage
from .distributed_data_parallel_config import DistributedDataParallelConfig

logger = logging.getLogger(__name__)
@@ -149,21 +152,43 @@ class _ParamAndGradBucketGroup:
        self.params_with_grad = set()
        self.is_last_microbatch = True

-    def check_for_nan_in_grad(self):
+    def check_grads(self, check_for_nan_or_inf, check_for_large):
        """
        Make sure norm of grads in bucket are not NaN prior to data-parallel
        all-reduce / reduce-scatter.
        """
        rerun_state_machine = get_rerun_state_machine()
        for i in range(len(self.buckets)):
-            rerun_state_machine.validate_result(
-                result=self.buckets[i].grad_data.norm(p=2),
-                rejection_func=torch.isnan,
-                message=f"found NaN in local grad norm for bucket #{i} "
-                f"in backward pass before data-parallel communication collective",
-                tolerance=0.001,  # 0.1% tolerance to account for non-deterministic FA backward
-                fatal=True,
-            )
+            grad_norm = self.buckets[i].grad_data.norm(p=2)
+            # check for NaN, Inf and unexpectedly large grads
+            if check_for_nan_or_inf:
+                rerun_state_machine.validate_result(
+                    result=grad_norm,
+                    rejection_func=torch.isnan,
+                    message=f"found NaN in local grad norm for bucket #{i} "
+                    f"in backward pass before data-parallel communication collective",
+                    tolerance=0.001,  # 0.1% tolerance to account for non-deterministic FA backward
+                    fatal=True,
+                )
+                rerun_state_machine.validate_result(
+                    result=grad_norm,
+                    rejection_func=torch.isinf,
+                    message=f"found Inf in local grad norm for bucket #{i} "
+                    f"in backward pass before data-parallel communication collective",
+                    tolerance=0.001,  # 0.1% tolerance to account for non-deterministic FA backward
+                    fatal=True,
+                )
+            if check_for_large:
+                rerun_state_machine.validate_result(
+                    result=grad_norm,
+                    rejection_func=partial(
+                        rerun_state_machine.is_unexpectedly_large, threshold=10, context="grads"
+                    ),
+                    message=f"found unexpected large grads in bucket #{i} "
+                    f"in backward pass before data-parallel communication collective",
+                    tolerance=0.001,  # 0.1% tolerance to account for non-deterministic FA backward
+                    fatal=False,
+                )

    def start_param_sync(self, force_sync: bool = False):
        """
@@ -239,9 +264,17 @@ class _ParamAndGradBucketGroup:
        if self.param_gather_handle is not None:
            self.param_gather_handle.wait()
            self.param_gather_handle = None
-            # Dispatch next bucket's asynchronous param AG.
+            # Dispatch next bucket's asynchronous param AG only if it has not been dispatched yet.
            if self.next_param_gather_bucket_group is not None and not skip_next_bucket_dispatch:
-                self.next_param_gather_bucket_group.start_param_sync()
+                if self.next_param_gather_bucket_group.param_gather_dispatched:
+                    warnings.warn(
+                        "The next bucket's parameter all-gather operation has already been "
+                        "dispatched. This may be caused by a mismatch between the order of "
+                        "parameter registration and forward pass execution, which will "
+                        "hurt the communication-computation overlap performance."
+                    )
+                else:
+                    self.next_param_gather_bucket_group.start_param_sync()

    def start_grad_sync(self):
        """
@@ -256,8 +289,11 @@ class _ParamAndGradBucketGroup:
            self.grad_reduce_handle is None
        ), 'Should not have multiple communication calls outstanding at once'

-        if self.ddp_config.check_for_nan_in_grad:
-            self.check_for_nan_in_grad()
+        if self.ddp_config.check_for_nan_in_grad or self.ddp_config.check_for_large_grads:
+            self.check_grads(
+                check_for_nan_or_inf=self.ddp_config.check_for_nan_in_grad,
+                check_for_large=self.ddp_config.check_for_large_grads,
+            )

        # gradient_scaling_factor already takes into account whether we are computing
        # an average or sum in the data-parallel collective.
@@ -270,13 +306,12 @@ class _ParamAndGradBucketGroup:
        if self.ddp_config.average_in_collective:
            reduce_op = torch.distributed.ReduceOp.AVG

-        # Stream synchronization logic of the CUDA streams that is
-        # implemented below for the gradient reduction within and across
-        # distributed optimizer instances.
-        # Compute Stream - -------------Gradient Compute-------------------
-        # Comm. Stream   - ------(wait for nccl)-----(wait for nccl)-------
-        # NCCL Stream    - -------RS------           -------AR------
+        # We use the following stream synchronization for the gradient reduction
+        # within and across DistOpt instances.
+        # Compute Stream: -------------Gradient compute-------------------
+        # Comm. Stream:   ------(wait for NCCL)-----(wait for NCCL)-------
+        # NCCL Stream:    -------RS------           -------AR------

        # Use async communications only when overlap_grad_reduce is True.
        async_op = (
@@ -287,13 +322,13 @@ class _ParamAndGradBucketGroup:
            self.ddp_config.num_distributed_optimizer_instances > 1
            and self.ddp_config.overlap_grad_reduce
        ):
-            # Assign a communication stream if we use partial DP DistOpt and we
-            # need to overlap communication
+            # Assign a communication stream if we have multiple DistOpt instances and we
+            # need to overlap communication.
            stream_context = torch.cuda.stream(self.communication_stream)

            # The RS/AR communication stream needs to wait for the default stream
            # to complete its gradient computation before launching the next
-            # gradient reduction collective
+            # gradient reduction collective.
            self.communication_stream.wait_stream(torch.cuda.default_stream())
        else:
            stream_context = nullcontext()
@@ -314,24 +349,22 @@ class _ParamAndGradBucketGroup:
                    local_data_view,
                    bucket.grad_data,
                    op=reduce_op,
-                    group=self.intra_distributed_optimizer_instance_group,
+                    group=communication_group,
                    async_op=async_op,
                )
            else:
                torch.distributed.all_reduce(
-                    bucket.grad_data,
-                    op=reduce_op,
-                    group=self.data_parallel_group,
-                    async_op=async_op,
+                    bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op
                )

-        # When enabling partial DP domain DistOpt, we need to All-Reduce across all partial domains
+        # With multiple DistOpt instances, we need to all-reduce across instances.
        if (
            self.ddp_config.use_distributed_optimizer
            and self.ddp_config.num_distributed_optimizer_instances > 1
        ):
-            # Create a new coalescing facility for the inter partial DP-AllReduce here
+            assert self.inter_distributed_optimizer_instance_group is not None
+            # Create a new coalescing manager for the inter-instance all-reduce.
            with stream_context, _coalescing_manager(
                self.inter_distributed_optimizer_instance_group, async_ops=async_op
            ) as cm:
@@ -366,13 +399,13 @@ class _ParamAndGradBucketGroup:
        communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
        makes synchronous call.
        """
-        # If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
        self.param_gather_dispatched = False
+        # If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
        if not self.ddp_config.overlap_grad_reduce:
            self.start_grad_sync()
            return
-        # When using partial DP DistOpt, we don't need to sync as we launch comms on a separate
-        # communication stream
+        # When using multiple DistOpt instances, we don't need to sync here as we launch
+        # communications on a separate communication stream.
        if self.ddp_config.num_distributed_optimizer_instances > 1:
            torch.cuda.default_stream().wait_stream(self.communication_stream)
            return
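The stream diagram and `wait_stream` calls above follow a standard producer/consumer handshake between the default (compute) stream and a dedicated communication stream. A minimal sketch under the assumption of an already-initialized NCCL process group and CUDA device (names are illustrative, not Megatron's implementation):

```python
import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()


def start_grad_reduce(grad_buffer: torch.Tensor, group=None):
    # The communication stream must not start reducing until the default stream
    # has finished producing the gradients (mirrors communication_stream.wait_stream
    # in start_grad_sync above).
    comm_stream.wait_stream(torch.cuda.default_stream())
    with torch.cuda.stream(comm_stream):
        return dist.all_reduce(grad_buffer, group=group, async_op=True)


def finish_grad_reduce():
    # Before the optimizer reads the reduced gradients, the default stream waits
    # on the communication stream (mirrors finish_grad_sync above).
    torch.cuda.default_stream().wait_stream(comm_stream)
```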
@@ -474,7 +507,15 @@ class _ParamAndGradBuffer:
                # This also helps cuBLAS pick more efficient algorithms for GEMMs.
                # We now ensure that all buckets start at a memory address that is 256-byte
                # aligned (128 values since params and grads use >= 16-bit precision).
-                return _pad(bucket_end_index, math.lcm(self.data_parallel_world_size, 128))
+                if self.ddp_config.pad_buckets_for_high_nccl_busbw:
+                    # Make sure the bucket size is divisible by a large power of 2 (2^16) to
+                    # ensure NCCL collectives have high bus bandwidth at large DP counts,
+                    # since NCCL message size (which for ring algorithms is bucket_size /
+                    # dp_size) apparently needs to be divisible by a power of 2 for high busbw.
+                    bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128, 2**16)
+                else:
+                    bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128)
+                return _pad(bucket_end_index, bucket_size_divisor)
            return bucket_end_index

        def _pad_start_of_param_if_needed(param_start_index: int) -> int:
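A quick arithmetic check of the two divisors above, assuming `_pad` rounds an index up to the nearest multiple of its divisor (which is what its usage implies); the sizes are illustrative:

```python
import math

dp_size = 64
bucket_end_index = 12_345_678


def _pad(index: int, divisor: int) -> int:
    # Assumed behavior of the _pad helper: round up to the nearest multiple.
    return int(math.ceil(index / divisor)) * divisor


default_divisor = math.lcm(dp_size, 128)       # 128 values -> 256-byte alignment
busbw_divisor = math.lcm(dp_size, 128, 2**16)  # additionally divisible by 2^16

print(_pad(bucket_end_index, default_divisor))  # 12345728
print(_pad(bucket_end_index, busbw_divisor))    # 12386304 (per-rank shard: 193536 elements)
```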
@@ -656,7 +697,10 @@ class _ParamAndGradBuffer:
            numel = 0
            for param in bucket.params:
                numel += param.data.nelement()
-            log_strs.append(f'Params for bucket {index+1} ({numel} elements):')
+            log_strs.append(
+                f"Params for bucket {index+1} ({numel} elements, "
+                f"{bucket.grad_data.nelement()} padded size):"
+            )
            for param in bucket.params:
                log_strs.append(f'\t{param_to_name[param]}')
        log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))
......
@@ -12,12 +12,15 @@ try:
except ImportError:
    HAVE_FSDP = False

+from megatron.core.fp8_utils import is_float8tensor
+
from .. import parallel_state, tensor_parallel
from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from ..transformer.transformer_config import TransformerConfig
from ..transformer.transformer_layer import TransformerLayer
from .data_parallel_base import _BaseDataParallel
+from .distributed_data_parallel_config import DistributedDataParallelConfig


class TorchFullyShardedDataParallel(_BaseDataParallel):
@@ -29,6 +32,7 @@ class TorchFullyShardedDataParallel(_BaseDataParallel):

    Args:
        config: Transformer config object.
+        ddp_config: DistributedDataParallel config object.
        module: Underlying model.
        sub_modules_to_wrap: List of sub_modules to shard with FSDP.
            Parameters within each sub_module will be all-gathered just-in-time.
@@ -43,6 +47,7 @@ class TorchFullyShardedDataParallel(_BaseDataParallel):
    def __init__(
        self,
        config: TransformerConfig,
+        ddp_config: DistributedDataParallelConfig,
        module: torch.nn.Module,
        sub_modules_to_wrap: List[torch.nn.Module] = [
            TransformerLayer,
@@ -50,7 +55,6 @@ class TorchFullyShardedDataParallel(_BaseDataParallel):
            RotaryEmbedding,
            tensor_parallel.ColumnParallelLinear,
        ],
-        **kwargs
    ):
        assert (
@@ -62,14 +66,18 @@ class TorchFullyShardedDataParallel(_BaseDataParallel):
            with_context_parallel=True
        )

-        mesh = DeviceMesh.from_group(self.data_parallel_group, "cuda")
-        kwargs = {"mesh": mesh}
+        kwargs = {"mesh": DeviceMesh.from_group(self.data_parallel_group, "cuda")}

        def save_custom_attrs(module):
            custom_attrs = {}
            for name, param in module.named_parameters():
                attrs = vars(param)
+                if is_float8tensor(param):
+                    # disable fp8 transpose cache and perform transposing fp8 weights
+                    # at each micro-batch because torch-FSDP doesn't recognize the
+                    # micro-batch id, thus removing unnecessary memory stores
+                    attrs['_fp8_attrs']['transpose_invalid'] = False
+                    del attrs['_fp8_attrs']['transpose']
                custom_attrs[name] = {k: v for k, v in attrs.items()}
            return custom_attrs
......
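The `save_custom_attrs` change above fits a broader pattern: Megatron hangs extra attributes off parameters, and wrapping with torch FSDP re-creates the parameter objects, so those attributes are stashed first and re-applied afterwards. A hedged sketch of that save/restore pattern (the restore half is not shown in this diff and is written here from its usage, so treat it as illustrative):

```python
import torch


def save_custom_attrs(module: torch.nn.Module):
    # vars(param) is the parameter's instance __dict__, i.e. only the extra
    # attributes attached to it (tensor-parallel flags, fp8 metadata, ...).
    return {name: dict(vars(param)) for name, param in module.named_parameters()}


def restore_custom_attrs(module: torch.nn.Module, custom_attrs):
    # Re-apply the stashed attributes onto the (possibly re-created) parameters.
    for name, param in module.named_parameters():
        for key, value in custom_attrs.get(name, {}).items():
            setattr(param, key, value)
```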
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -21,6 +21,10 @@ DEFAULT_CONVERSION_DICT = {
    'decoder.layers.mlp.linear_fc1.bias': TRTLLMLayers.mlp_fc_bias,
    'decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight,
    'decoder.layers.mlp.linear_fc2.bias': TRTLLMLayers.mlp_projection_bias,
+    # EXPERTS
+    'decoder.layers.mlp.experts.experts.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight_mixture_of_experts,
+    'decoder.layers.mlp.experts.experts.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight_mixture_of_experts,
+    'decoder.layers.mlp.router.weight': TRTLLMLayers.mlp_router_weight,
    # FINAL LAYER NORM
    'decoder.final_layernorm.weight': TRTLLMLayers.final_layernorm_weight,
    'decoder.final_layernorm.bias': TRTLLMLayers.final_layernorm_bias,
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644