You need to sign in or sign up before continuing.
Unverified Commit de4f14e9 authored by ZhangShilong's avatar ZhangShilong Committed by GitHub
Browse files

[Enhancement]: refactor init cfg (#958)

parent 5504b5fb
...@@ -22,8 +22,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -22,8 +22,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
# define default value of init_cfg instead of hard code # define default value of init_cfg instead of hard code
# in init_weigt() function # in init_weigt() function
self._is_init = False self._is_init = False
if init_cfg is not None: self.init_cfg = init_cfg
self.init_cfg = init_cfg
# Backward compatibility in derived classes # Backward compatibility in derived classes
# if pretrained is not None: # if pretrained is not None:
...@@ -40,7 +39,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -40,7 +39,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
from ..cnn import initialize from ..cnn import initialize
if not self._is_init: if not self._is_init:
if hasattr(self, 'init_cfg'): if self.init_cfg:
initialize(self, self.init_cfg) initialize(self, self.init_cfg)
for m in self.children(): for m in self.children():
if hasattr(m, 'init_weight'): if hasattr(m, 'init_weight'):
...@@ -52,7 +51,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta): ...@@ -52,7 +51,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
def __repr__(self): def __repr__(self):
s = super().__repr__() s = super().__repr__()
if hasattr(self, 'init_cfg'): if self.init_cfg:
s += f'\ninit_cfg={self.init_cfg}' s += f'\ninit_cfg={self.init_cfg}'
return s return s
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment