Commit a1a7899c (unverified), authored by ver217, committed by GitHub
Browse files

[hotfix] fix zero init ctx numel (#1128)

parent f0a954f1
@@ -78,6 +78,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         ZeroContextMgr().current_context = self
+        self.param_numel = {}
+        self.top_module = None
     @property
     def target_device(self):
         return self.config.target_device
@@ -169,11 +172,18 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         torch.set_rng_state(self.cpu_rng_state)
         torch.cuda.set_rng_state(self.cuda_rng_state)
+        params = frozenset(self.top_module.parameters())
+        for param in self.param_numel.keys():
+            if param not in params:
+                self.param_numel[param] = 0
+        self.model_numel_tensor.fill_(sum(self.param_numel.values()))
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """
+        self.top_module = module
         def half_fn(t: torch.Tensor):
             return t.half() if t.is_floating_point() else t
@@ -183,7 +193,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if hasattr(param, 'colo_attr'):
                 continue
-            self.model_numel_tensor += param.numel()
+            self.param_numel[param] = param.numel()
             # convert parameters to half
             param_half = half_fn(param)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.