Commit 99a0c39e authored by xingjinliang

Sync the latest code

parent 50fe58fa
@@ -104,8 +104,6 @@ def load(
     checkpoint_dir = Path(checkpoint_dir)
     common_state_dict = common_strategy.load_common(checkpoint_dir)
-    if not sharded_state_dict:
-        return common_state_dict

     sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
         sharded_state_dict
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -412,7 +412,7 @@ def validate_sharding_integrity(
         CheckpointingException for invalid access pattern
     """
-    if common_state_dict:
+    if common_state_dict is not None:
         _validate_common_state_dict(common_state_dict)

     if torch.distributed.get_rank() != 0:
@@ -461,10 +461,15 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
             lambda x: x[1],
             _validate_sharding_for_key_flattened,
         )
-    else:
-        if not torch.all(shard_access_cnt == 1):
-            logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
-            raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
+        # For each shard with at least 1 flattened tensor in it, the above
+        # `_validate_sharding_for_key_flattened` ensures a correct, consistent pattern.
+        # The only thing that can go wrong at this point is that some shard doesn't have
+        # *any* representatives, which will be checked later by comparing `shard_access_cnt == 1`.
+        shard_access_cnt = torch.minimum(shard_access_cnt, torch.tensor([1]))
+    if not torch.all(shard_access_cnt == 1):
+        raise CheckpointingException(
+            f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}'
+        )


 def _compute_shards_access(rank_sharding):
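Note: a minimal standalone sketch of the clamping trick used above, with toy access counts (this is not the Megatron implementation; `check_access_pattern` is a made-up name):

import torch

def check_access_pattern(shard_access_cnt: torch.Tensor, has_flattened_range: bool) -> None:
    # Toy stand-in for the per-shard access counts computed from rank shardings.
    if has_flattened_range:
        # Several flattened slices may legitimately touch the same shard, so the
        # count can exceed 1; slice contiguity is validated separately, and clamping
        # leaves only "no representative at all" (count == 0) detectable below.
        shard_access_cnt = torch.minimum(shard_access_cnt, torch.tensor([1]))
    if not torch.all(shard_access_cnt == 1):
        raise ValueError(f'Invalid access pattern: {shard_access_cnt}')

check_access_pattern(torch.tensor([3, 2, 1]), has_flattened_range=True)  # passes after clamping
try:
    check_access_pattern(torch.tensor([1, 0, 1]), has_flattened_range=True)
except ValueError as err:
    print(err)  # the shard with count 0 has no representative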
@@ -489,16 +494,10 @@ def _validate_sharding_for_key_flattened(tensors_by_shard):
         all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))

     starts, stops = map(np.asarray, zip(*sorted(all_slices)))
-    if (
-        starts[0] != 0
-        or stops[-1] != np.product(local_shape)
-        or not np.all(starts[1:] == stops[:-1])
-    ):
-        logger.error(
-            f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
-        )
-        raise CheckpointingException(
-            f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
-        )
+    expected_size = np.product(local_shape)
+    if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]):
+        raise CheckpointingException(
+            f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}'
+        )
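Note: a hedged, self-contained sketch of the coverage check this hunk simplifies: the sorted (start, stop) ranges must begin at 0, end at the flattened shard size, and butt up against each other with no gaps (toy data; `ranges_cover_shard` is a hypothetical helper, not part of the library):

import numpy as np

def ranges_cover_shard(slices, local_shape) -> bool:
    """Return True if sorted (start, stop) ranges exactly tile a flattened shard."""
    starts, stops = map(np.asarray, zip(*sorted(slices)))
    expected_size = np.prod(local_shape)  # np.product is a deprecated alias of np.prod
    return bool(
        starts[0] == 0
        and stops[-1] == expected_size
        and np.all(starts[1:] == stops[:-1])
    )

print(ranges_cover_shard([(0, 4), (4, 10), (10, 12)], local_shape=(3, 4)))  # True: tiles 0..12
print(ranges_cover_shard([(0, 4), (5, 12)], local_shape=(3, 4)))            # False: gap at [4, 5)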
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -7,6 +7,7 @@ import torch

 from .. import parallel_state
 from ..config_logger import has_config_logger_enabled, log_config_to_disk
+from ..transformer.cuda_graphs import is_graph_capturing
 from ..transformer.transformer_config import TransformerConfig
 from ..utils import is_float8tensor, log_single_rank
 from .data_parallel_base import _BaseDataParallel
@@ -151,12 +152,20 @@ class DistributedDataParallel(_BaseDataParallel):
                     with_context_parallel=True
                 )
                 if self.ddp_config.average_in_collective:
-                    # Collective is averaging gradients in collective with data_parallel_group.
-                    assert (
-                        gradient_scaling_factor
-                        / parallel_state.get_data_parallel_world_size(with_context_parallel=True)
-                        == target_gradient_scaling_factor
-                    )
+                    if self.ddp_config.num_distributed_optimizer_instances == 1:
+                        # Collective is averaging gradients in collective with data_parallel_group.
+                        assert (
+                            gradient_scaling_factor
+                            / torch.distributed.get_world_size(group=data_parallel_group)
+                            == target_gradient_scaling_factor
+                        )
+                    else:
+                        # For non-expert parameters, gradient_scaling_factor is 1.
+                        # For expert parameters, gradient_scaling_factor is 1/ep_size.
+                        assert (gradient_scaling_factor == 1) or (
+                            gradient_scaling_factor
+                            == (1.0 / parallel_state.get_expert_model_parallel_world_size())
+                        )
                 else:
                     assert gradient_scaling_factor == target_gradient_scaling_factor
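Note: a small arithmetic sketch of what the relaxed assertion permits; `dp_size` and `ep_size` are illustrative values, not taken from any real configuration:

# Single distributed-optimizer instance: grads are pre-scaled by 1 and the
# collective averages over the full data-parallel group.
dp_size = 8   # assumed data-parallel (with context-parallel) group size
ep_size = 4   # assumed expert-model-parallel group size

target_gradient_scaling_factor = 1.0 / dp_size

gradient_scaling_factor = 1.0
assert gradient_scaling_factor / dp_size == target_gradient_scaling_factor

# Multiple instances: non-expert grads keep factor 1, expert grads use 1/ep_size,
# which is exactly the pair of cases the relaxed assertion accepts.
for gradient_scaling_factor in (1.0, 1.0 / ep_size):
    assert gradient_scaling_factor == 1 or gradient_scaling_factor == 1.0 / ep_size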
@@ -297,9 +306,10 @@ class DistributedDataParallel(_BaseDataParallel):
                 self._make_forward_pre_hook()
             )

-    def disable_forward_pre_hook(self):
+    def disable_forward_pre_hook(self, param_sync: bool = True):
         """
         Disable forward pre-hooks needed for param all-gather overlap with forward compute.
+        Skip synchronous param all-gather if `param_sync` is False.
         """
         assert self.use_forward_hook
         # De-register forward pre-hook for all sub-modules.
@@ -310,7 +320,8 @@ class DistributedDataParallel(_BaseDataParallel):
         assert len(self.remove_forward_pre_hook_handles) == 0

         # Force synchronize parameters.
-        self.start_param_sync(force_sync=True)
+        if param_sync:
+            self.start_param_sync(force_sync=True)

     def _make_forward_pre_hook(self):
         """
@@ -323,6 +334,9 @@ class DistributedDataParallel(_BaseDataParallel):
                 self.use_forward_hook
             ), "Should use pre-hook only when overlap_param_gather is True"

+            if is_graph_capturing():
+                return
+
             # Make sure all parameters in this module have been all-gathered as necessary.
             for param in module.parameters(recurse=False):
                 # Skip parameters without an associated buffer (such parameters have a
@@ -353,6 +367,9 @@ class DistributedDataParallel(_BaseDataParallel):
         """

         def hook(*unused):
+            if is_graph_capturing():
+                return
+
             if param in self.param_to_bucket_group:
                 assert param.requires_grad
                 if self.ddp_config.overlap_grad_reduce:
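Note: a hedged sketch of the guard pattern these hooks adopt, written against the public `torch.cuda.is_current_stream_capturing()` API; whether Megatron's `is_graph_capturing()` is built on top of it is an assumption:

import torch

def grad_hook(*unused):
    # Bail out while a CUDA graph is being captured: hook-side bookkeeping and any
    # collectives it triggers must not be recorded into the graph being replayed.
    if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
        return
    print("mark grad as ready for bucket-level grad reduction")

grad_hook()  # outside capture: bookkeeping runs normally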