Unverified Commit 0f7ed8c1 authored by ver217's avatar ver217 Committed by GitHub
Browse files

fix _post_init_method of zero init ctx (#847)

parent 2a0a427e
......@@ -155,7 +155,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
torch.set_rng_state(self.cpu_rng_state)
torch.cuda.set_rng_state(self.cuda_rng_state)
def _post_init_method(self, module: torch.nn.Module):
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
"""
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment