Unverified commit 6690a61b, authored by ver217 and committed by GitHub
Browse files

[hotfix] prevent nested ZeRO (#1140)

parent 15aab147
...@@ -77,6 +77,7 @@ class ShardedModelV2(nn.Module): ...@@ -77,6 +77,7 @@ class ShardedModelV2(nn.Module):
tensor_placement_policy: str = 'cuda', tensor_placement_policy: str = 'cuda',
gradient_predivide_factor: Optional[float] = 1.0, gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False): reuse_fp16_shard: bool = False):
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
super().__init__() super().__init__()
self.logger = get_dist_logger() self.logger = get_dist_logger()
......
...@@ -87,6 +87,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): ...@@ -87,6 +87,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
mp_process_group: Optional[ProcessGroup] = None, mp_process_group: Optional[ProcessGroup] = None,
verbose: bool = False) -> None: verbose: bool = False) -> None:
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel' assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.'
super().__init__(optimizer) super().__init__(optimizer)
self.shard_strategy = sharded_model.shard_strategy self.shard_strategy = sharded_model.shard_strategy
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment