Commit 0d99ae1f authored by silencealiang's avatar silencealiang
Browse files

add

parent c271aaae
Pipeline #2498 canceled with stages
File mode changed from 100755 to 100644
...@@ -104,8 +104,6 @@ def load( ...@@ -104,8 +104,6 @@ def load(
checkpoint_dir = Path(checkpoint_dir) checkpoint_dir = Path(checkpoint_dir)
common_state_dict = common_strategy.load_common(checkpoint_dir) common_state_dict = common_strategy.load_common(checkpoint_dir)
if not sharded_state_dict:
return common_state_dict
sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess( sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict sharded_state_dict
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
...@@ -412,7 +412,7 @@ def validate_sharding_integrity( ...@@ -412,7 +412,7 @@ def validate_sharding_integrity(
CheckpointingException for invalid access pattern CheckpointingException for invalid access pattern
""" """
if common_state_dict: if common_state_dict is not None:
_validate_common_state_dict(common_state_dict) _validate_common_state_dict(common_state_dict)
if torch.distributed.get_rank() != 0: if torch.distributed.get_rank() != 0:
...@@ -461,10 +461,15 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]): ...@@ -461,10 +461,15 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
lambda x: x[1], lambda x: x[1],
_validate_sharding_for_key_flattened, _validate_sharding_for_key_flattened,
) )
else: # For each shard with at least 1 flattened tensor in it, the above
if not torch.all(shard_access_cnt == 1): # `_validate_sharding_for_key_flattened` ensure a correct consistent pattern
logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}') # The only thing that can go wrong at this point is that some shard don't have
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}') # *any* representatives which will be checked later by comparing `shard_access_cnt == 1`
shard_access_cnt = torch.minimum(shard_access_cnt, torch.tensor([1]))
if not torch.all(shard_access_cnt == 1):
raise CheckpointingException(
f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}'
)
def _compute_shards_access(rank_sharding): def _compute_shards_access(rank_sharding):
...@@ -489,16 +494,10 @@ def _validate_sharding_for_key_flattened(tensors_by_shard): ...@@ -489,16 +494,10 @@ def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop)) all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
starts, stops = map(np.asarray, zip(*sorted(all_slices))) starts, stops = map(np.asarray, zip(*sorted(all_slices)))
if ( expected_size = np.product(local_shape)
starts[0] != 0 if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]):
or stops[-1] != np.product(local_shape)
or not np.all(starts[1:] == stops[:-1])
):
logger.error(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
)
raise CheckpointingException( raise CheckpointingException(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}' f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}'
) )
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment