Unverified Commit 5b185430 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[hotfix] fix zero ddp warmup check (#2545)

parent fa3d66fe
...@@ -58,6 +58,10 @@ class GeminiManager: ...@@ -58,6 +58,10 @@ class GeminiManager:
self._evict_time = 0 self._evict_time = 0
self._comp_cuda_demand_time = 0 self._comp_cuda_demand_time = 0
@property
def need_warmup(self) -> bool:
return self.policy_name in ('auto', 'const')
def is_warmup(self): def is_warmup(self):
return self._warmup return self._warmup
......
...@@ -269,7 +269,8 @@ class ZeroDDP(ColoDDP): ...@@ -269,7 +269,8 @@ class ZeroDDP(ColoDDP):
# check whether we are in a inference mode # check whether we are in a inference mode
grad_flag = torch.is_grad_enabled() grad_flag = torch.is_grad_enabled()
if not grad_flag: if not grad_flag:
assert not self.gemini_manager.is_warmup(), "You should run a completed iteration as your warmup iter" assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
), "You should run a completed iteration as your warmup iter"
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True) self.module.zero_grad(set_to_none=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment