Unverified Commit a1a7899c authored by ver217's avatar ver217 Committed by GitHub
Browse files

[hotfix] fix zero init ctx numel (#1128)

parent f0a954f1
......@@ -78,6 +78,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
ZeroContextMgr().current_context = self
self.param_numel = {}
self.top_module = None
@property
def target_device(self):
return self.config.target_device
......@@ -169,11 +172,18 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
torch.set_rng_state(self.cpu_rng_state)
torch.cuda.set_rng_state(self.cuda_rng_state)
params = frozenset(self.top_module.parameters())
for param in self.param_numel.keys():
if param not in params:
self.param_numel[param] = 0
self.model_numel_tensor.fill_(sum(self.param_numel.values()))
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
"""
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.
"""
self.top_module = module
def half_fn(t: torch.Tensor):
return t.half() if t.is_floating_point() else t
......@@ -183,7 +193,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if hasattr(param, 'colo_attr'):
continue
self.model_numel_tensor += param.numel()
self.param_numel[param] = param.numel()
# convert parameters to half
param_half = half_fn(param)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment