Commit 0d99ae1f authored by silencealiang

add

parent c271aaae
Pipeline #2498 canceled with stages
@@ -7,6 +7,7 @@ import torch
 from .. import parallel_state
 from ..config_logger import has_config_logger_enabled, log_config_to_disk
+from ..transformer.cuda_graphs import is_graph_capturing
 from ..transformer.transformer_config import TransformerConfig
 from ..utils import is_float8tensor, log_single_rank
 from .data_parallel_base import _BaseDataParallel
@@ -151,12 +152,20 @@ class DistributedDataParallel(_BaseDataParallel):
                     with_context_parallel=True
                 )
             if self.ddp_config.average_in_collective:
-                # Collective is averaging gradients in collective with data_parallel_group.
-                assert (
-                    gradient_scaling_factor
-                    / parallel_state.get_data_parallel_world_size(with_context_parallel=True)
-                    == target_gradient_scaling_factor
-                )
+                if self.ddp_config.num_distributed_optimizer_instances == 1:
+                    # Collective is averaging gradients in collective with data_parallel_group.
+                    assert (
+                        gradient_scaling_factor
+                        / torch.distributed.get_world_size(group=data_parallel_group)
+                        == target_gradient_scaling_factor
+                    )
+                else:
+                    # For non-expert parameters, gradient_scaling_factor is 1.
+                    # For expert parameters, gradient_scaling_factor is 1/ep_size.
+                    assert (gradient_scaling_factor == 1) or (
+                        gradient_scaling_factor
+                        == (1.0 / parallel_state.get_expert_model_parallel_world_size())
+                    )
             else:
                 assert gradient_scaling_factor == target_gradient_scaling_factor
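
Note on the assertion logic above: with `average_in_collective` and a single distributed-optimizer instance, the averaging collective itself divides by the size of `data_parallel_group`, so the pre-applied `gradient_scaling_factor` divided by that group size must equal the target factor `1 / dp_size`; with multiple instances, non-expert parameters carry a factor of 1 and expert parameters carry `1 / ep_size`. A minimal standalone sketch of that arithmetic (plain Python; `dp_size` and `ep_size` are assumed example values, not taken from this diff):

```python
# Standalone arithmetic sketch of the assertions above; sizes are assumed examples.
dp_size = 8   # data-parallel world size (with context parallel)
ep_size = 2   # expert-model-parallel world size

target_gradient_scaling_factor = 1.0 / dp_size

# Single DistOpt instance: the AVG collective already divides by dp_size,
# so a pre-applied factor of 1.0 satisfies the check.
gradient_scaling_factor = 1.0
assert gradient_scaling_factor / dp_size == target_gradient_scaling_factor

# Multiple DistOpt instances: non-expert params use 1, expert params use 1 / ep_size.
for factor in (1.0, 1.0 / ep_size):
    assert (factor == 1) or (factor == 1.0 / ep_size)
```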
@@ -297,9 +306,10 @@ class DistributedDataParallel(_BaseDataParallel):
                 self._make_forward_pre_hook()
             )

-    def disable_forward_pre_hook(self):
+    def disable_forward_pre_hook(self, param_sync: bool = True):
         """
         Disable forward pre-hooks needed for param all-gather overlap with forward compute.
+        Skip synchronous param all-gather if `param_sync` is False.
         """
         assert self.use_forward_hook
         # De-register forward pre-hook for all sub-modules.
@@ -310,7 +320,8 @@ class DistributedDataParallel(_BaseDataParallel):
         assert len(self.remove_forward_pre_hook_handles) == 0

         # Force synchronize parameters.
-        self.start_param_sync(force_sync=True)
+        if param_sync:
+            self.start_param_sync(force_sync=True)

     def _make_forward_pre_hook(self):
         """
@@ -323,6 +334,9 @@ class DistributedDataParallel(_BaseDataParallel):
                 self.use_forward_hook
             ), "Should use pre-hook only when overlap_param_gather is True"

+            if is_graph_capturing():
+                return
+
             # Make sure all parameters in this module have been all-gathered as necessary.
             for param in module.parameters(recurse=False):
                 # Skip parameters without an associated buffer (such parameters have a
@@ -353,6 +367,9 @@ class DistributedDataParallel(_BaseDataParallel):
         """

         def hook(*unused):
+            if is_graph_capturing():
+                return
+
             if param in self.param_to_bucket_group:
                 assert param.requires_grad
                 if self.ddp_config.overlap_grad_reduce:
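
Both `is_graph_capturing()` early returns follow the same pattern: while a CUDA graph is being captured, hooks that would dispatch collectives or other side work must become no-ops, since capture only records work on the capturing stream. A generic sketch of that guard using the public PyTorch capture check (the hook body is a placeholder, not code from this diff):

```python
import torch


def make_guarded_pre_hook(launch_param_gather):
    """Return a forward pre-hook that is a no-op during CUDA graph capture."""

    def hook(module, args):
        # torch.cuda.is_current_stream_capturing() plays the same role here as
        # the is_graph_capturing() helper imported in this diff.
        if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
            return
        # Outside of capture, dispatch the (placeholder) parameter all-gather.
        launch_param_gather(module)

    return hook
```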
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -270,13 +270,12 @@ class _ParamAndGradBucketGroup:
         if self.ddp_config.average_in_collective:
             reduce_op = torch.distributed.ReduceOp.AVG

-        # Stream synchronization logic of the CUDA streams that is
-        # implemented below for the gradient reduction within and across
-        # distributed optimizer instances.
+        # We use the following stream synchronization for the gradient reduction
+        # within and across DistOpt instances.

-        # Compute Stream - -------------Gradient Compute-------------------
-        # Comm. Stream   - ------(wait for nccl)-----(wait for nccl)-------
-        # NCCL Stream    -       -------RS------      -------AR------
+        # Compute Stream: -------------Gradient compute-------------------
+        # Comm. Stream:   ------(wait for NCCL)-----(wait for NCCL)-------
+        # NCCL Stream:          -------RS------      -------AR------

         # Use async communications only when overlap_grad_reduce is True.
         async_op = (
@@ -287,13 +286,13 @@ class _ParamAndGradBucketGroup:
             self.ddp_config.num_distributed_optimizer_instances > 1
             and self.ddp_config.overlap_grad_reduce
         ):
-            # Assign a communication stream if we use partial DP DistOpt and we
-            # need to overlap communication
+            # Assign a communication stream if we have multiple DistOpt instances and we
+            # need to overlap communication.
             stream_context = torch.cuda.stream(self.communication_stream)

             # The RS/AR communication stream needs to wait for the default stream
             # to complete its gradient computation before launching the next
-            # gradient reduction collective
+            # gradient reduction collective.
             self.communication_stream.wait_stream(torch.cuda.default_stream())
         else:
             stream_context = nullcontext()
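
The diagram and `wait_stream` call above describe a standard producer/consumer ordering between the default (compute) stream and a dedicated communication stream: the reduction must not start until the gradients it consumes have been produced. A self-contained sketch of that ordering, with a plain tensor op standing in for the NCCL reduction:

```python
import torch


def reduce_on_side_stream(grad: torch.Tensor, comm_stream: torch.cuda.Stream) -> None:
    # The communication stream waits for the default stream, so the "reduction"
    # below observes fully computed gradients.
    comm_stream.wait_stream(torch.cuda.default_stream())
    with torch.cuda.stream(comm_stream):
        grad.mul_(0.5)  # stand-in for an averaging reduce-scatter / all-reduce


if torch.cuda.is_available():
    comm_stream = torch.cuda.Stream()
    grad = torch.ones(1024, device="cuda")
    grad.add_(1.0)  # "gradient compute" on the default stream
    reduce_on_side_stream(grad, comm_stream)
    # Before the default stream consumes the result, it waits for the comm stream.
    torch.cuda.default_stream().wait_stream(comm_stream)
    torch.cuda.synchronize()
```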
@@ -314,24 +313,21 @@ class _ParamAndGradBucketGroup:
                         local_data_view,
                         bucket.grad_data,
                         op=reduce_op,
-                        group=self.intra_distributed_optimizer_instance_group,
+                        group=communication_group,
                         async_op=async_op,
                     )
                 else:
                     torch.distributed.all_reduce(
-                        bucket.grad_data,
-                        op=reduce_op,
-                        group=self.data_parallel_group,
-                        async_op=async_op,
+                        bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op
                     )

-        # When enabling partial DP domain DistOpt, we need to All-Reduce across all partial domains
+        # With multiple DistOpt instances, we need to all-reduce across instances.
         if (
             self.ddp_config.use_distributed_optimizer
             and self.ddp_config.num_distributed_optimizer_instances > 1
         ):
-            # Create a new coalescing facility for the inter partial DP-AllReduce here
+            # Create a new coalescing manager for the inter-instance all-reduce.
             with stream_context, _coalescing_manager(
                 self.inter_distributed_optimizer_instance_group, async_ops=async_op
             ) as cm:
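
The change above routes both collectives through a single `communication_group` and, when there is more than one distributed-optimizer instance, follows up with an all-reduce across instances. A hedged sketch of that two-level reduction using only public `torch.distributed` calls (process-group construction and initialization are assumed to happen elsewhere; the coalescing manager used in the real code is omitted):

```python
from typing import Optional

import torch
import torch.distributed as dist


def two_level_grad_reduce(
    grad_buffer: torch.Tensor,
    local_shard: torch.Tensor,
    intra_instance_group: dist.ProcessGroup,
    inter_instance_group: Optional[dist.ProcessGroup],
) -> None:
    """Reduce-scatter inside a DistOpt instance, then all-reduce shards across instances.

    Assumes grad_buffer.numel() == local_shard.numel() * intra-instance group size.
    """
    # Level 1: average gradients inside the instance; each rank keeps its own shard.
    dist.reduce_scatter_tensor(
        local_shard, grad_buffer, op=dist.ReduceOp.AVG, group=intra_instance_group
    )
    # Level 2: with multiple instances, combine matching shards across instances.
    if inter_instance_group is not None:
        dist.all_reduce(local_shard, op=dist.ReduceOp.AVG, group=inter_instance_group)
```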
@@ -366,13 +362,13 @@ class _ParamAndGradBucketGroup:
         communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
         makes synchronous call.
         """
-        # If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
         self.param_gather_dispatched = False
+        # If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
         if not self.ddp_config.overlap_grad_reduce:
             self.start_grad_sync()
             return
-        # When using partial DP DistOpt, we don't need to sync as we launch comms on a separate
-        # communication stream
+        # When using multiple DistOpt instances, we don't need to sync here as we launch
+        # communications on a separate communication stream.
         if self.ddp_config.num_distributed_optimizer_instances > 1:
             torch.cuda.default_stream().wait_stream(self.communication_stream)
             return
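
At the call site, the contract of `finish_grad_sync()` is unchanged by this commit: backward dispatches (possibly asynchronous) gradient reductions, and the trainer must finish the sync before the optimizer reads the gradients. A hedged sketch of that ordering (`model`, `optimizer`, and `loss_fn` are illustrative placeholders, not names from this diff):

```python
# Illustrative training-step ordering around grad sync; names are placeholders.
def train_step(model, optimizer, loss_fn, batch):
    optimizer.zero_grad()
    loss = loss_fn(model(batch))
    loss.backward()           # backward hooks may dispatch async grad reductions
    model.finish_grad_sync()  # wait (or reduce synchronously) before touching .grad
    optimizer.step()
    return loss.detach()
```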
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644