Unverified Commit 6690a61b authored by ver217's avatar ver217 Committed by GitHub
Browse files

[hotfix] prevent nested ZeRO (#1140)

parent 15aab147
......@@ -77,6 +77,7 @@ class ShardedModelV2(nn.Module):
tensor_placement_policy: str = 'cuda',
gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False):
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
super().__init__()
self.logger = get_dist_logger()
......
......@@ -87,6 +87,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
mp_process_group: Optional[ProcessGroup] = None,
verbose: bool = False) -> None:
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.'
super().__init__(optimizer)
self.shard_strategy = sharded_model.shard_strategy
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment