Unverified Commit 935ba78b authored by Zhongyu Li's avatar Zhongyu Li Committed by GitHub
Browse files

[Fix] Fix skip_layer for RF-Next (#2489)

* judge skip_layer by fullname

* lint

* skip_layer first

* update unit test
parent 30d975a5
...@@ -143,7 +143,10 @@ class RFSearchHook(Hook): ...@@ -143,7 +143,10 @@ class RFSearchHook(Hook):
module.estimate_rates() module.estimate_rates()
module.expand_rates() module.expand_rates()
def wrap_model(self, model: nn.Module, search_op: str = 'Conv2d'): def wrap_model(self,
model: nn.Module,
search_op: str = 'Conv2d',
prefix: str = ''):
"""wrap model to support searchable conv op. """wrap model to support searchable conv op.
Args: Args:
...@@ -152,9 +155,18 @@ class RFSearchHook(Hook): ...@@ -152,9 +155,18 @@ class RFSearchHook(Hook):
Defaults to 'Conv2d'. Defaults to 'Conv2d'.
init_rates (int, optional): Set to other initial dilation rates. init_rates (int, optional): Set to other initial dilation rates.
Defaults to None. Defaults to None.
prefix (str): Prefix for function recursion. Defaults to ''.
""" """
op = 'torch.nn.' + search_op op = 'torch.nn.' + search_op
for name, module in model.named_children(): for name, module in model.named_children():
if prefix == '':
fullname = 'module.' + name
else:
fullname = prefix + '.' + name
if self.config['search']['skip_layer'] is not None:
if any(layer in fullname
for layer in self.config['search']['skip_layer']):
continue
if isinstance(module, eval(op)): if isinstance(module, eval(op)):
if 1 < module.kernel_size[0] and \ if 1 < module.kernel_size[0] and \
0 != module.kernel_size[0] % 2 or \ 0 != module.kernel_size[0] % 2 or \
...@@ -167,14 +179,8 @@ class RFSearchHook(Hook): ...@@ -167,14 +179,8 @@ class RFSearchHook(Hook):
logger.info('Wrap model %s to %s.' % logger.info('Wrap model %s to %s.' %
(str(module), str(moduleWrap))) (str(module), str(moduleWrap)))
setattr(model, name, moduleWrap) setattr(model, name, moduleWrap)
elif isinstance(module, BaseConvRFSearchOp): elif not isinstance(module, BaseConvRFSearchOp):
pass self.wrap_model(module, search_op, fullname)
else:
if self.config['search']['skip_layer'] is not None:
if any(layer in name
for layer in self.config['search']['skip_layer']):
continue
self.wrap_model(module, search_op)
def set_model(self, def set_model(self,
model: nn.Module, model: nn.Module,
...@@ -198,6 +204,10 @@ class RFSearchHook(Hook): ...@@ -198,6 +204,10 @@ class RFSearchHook(Hook):
fullname = 'module.' + name fullname = 'module.' + name
else: else:
fullname = prefix + '.' + name fullname = prefix + '.' + name
if self.config['search']['skip_layer'] is not None:
if any(layer in fullname
for layer in self.config['search']['skip_layer']):
continue
if isinstance(module, eval(op)): if isinstance(module, eval(op)):
if 1 < module.kernel_size[0] and \ if 1 < module.kernel_size[0] and \
0 != module.kernel_size[0] % 2 or \ 0 != module.kernel_size[0] % 2 or \
...@@ -224,11 +234,5 @@ class RFSearchHook(Hook): ...@@ -224,11 +234,5 @@ class RFSearchHook(Hook):
logger.info( logger.info(
'Set module %s dilation as: [%d %d]' % 'Set module %s dilation as: [%d %d]' %
(fullname, module.dilation[0], module.dilation[1])) (fullname, module.dilation[0], module.dilation[1]))
elif isinstance(module, BaseConvRFSearchOp): elif not isinstance(module, BaseConvRFSearchOp):
pass
else:
if self.config['search']['skip_layer'] is not None:
if any(layer in fullname
for layer in self.config['search']['skip_layer']):
continue
self.set_model(module, search_op, init_rates, fullname) self.set_model(module, search_op, init_rates, fullname)
...@@ -16,36 +16,36 @@ from tests.test_runner.test_hooks import _build_demo_runner ...@@ -16,36 +16,36 @@ from tests.test_runner.test_hooks import _build_demo_runner
def test_rfsearchhook(): def test_rfsearchhook():
def conv(in_channels, out_channels, kernel_size, stride, padding,
dilation):
return nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
class Model(nn.Module): class Model(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.conv1 = nn.Conv2d( self.stem = conv(1, 2, 3, 1, 1, 1)
in_channels=1, self.conv0 = conv(2, 2, 3, 1, 1, 1)
out_channels=2, self.layer0 = nn.Sequential(
kernel_size=1, conv(2, 2, 3, 1, 1, 1), conv(2, 2, 3, 1, 1, 1))
stride=1, self.conv1 = conv(2, 2, 1, 1, 0, 1)
padding=0, self.conv2 = conv(2, 2, 3, 1, 1, 1)
dilation=1) self.conv3 = conv(2, 2, (1, 3), 1, (0, 1), 1)
self.conv2 = nn.Conv2d(
in_channels=2,
out_channels=2,
kernel_size=3,
stride=1,
padding=1,
dilation=1)
self.conv3 = nn.Conv2d(
in_channels=1,
out_channels=2,
kernel_size=(1, 3),
stride=1,
padding=(0, 1),
dilation=1)
def forward(self, x): def forward(self, x):
x1 = self.conv1(x) x1 = self.stem(x)
x2 = self.conv2(x1) x2 = self.layer0(x1)
return x2 x3 = self.conv0(x2)
x4 = self.conv1(x3)
x5 = self.conv2(x4)
x6 = self.conv3(x5)
return x6
def train_step(self, x, optimizer, **kwargs): def train_step(self, x, optimizer, **kwargs):
return dict(loss=self(x).mean(), num_samples=x.shape[0]) return dict(loss=self(x).mean(), num_samples=x.shape[0])
...@@ -63,13 +63,14 @@ def test_rfsearchhook(): ...@@ -63,13 +63,14 @@ def test_rfsearchhook():
mmin=1, mmin=1,
mmax=24, mmax=24,
num_branches=2, num_branches=2,
skip_layer=['stem', 'layer1'])), skip_layer=['stem', 'conv0', 'layer0.1'])),
) )
# hook for search # hook for search
rfsearchhook_search = RFSearchHook( rfsearchhook_search = RFSearchHook(
'search', rfsearch_cfg['config'], by_epoch=True, verbose=True) 'search', rfsearch_cfg['config'], by_epoch=True, verbose=True)
rfsearchhook_search.config['structure'] = { rfsearchhook_search.config['structure'] = {
'module.layer0.0': [1, 1],
'module.conv2': [2, 2], 'module.conv2': [2, 2],
'module.conv3': [1, 1] 'module.conv3': [1, 1]
} }
...@@ -80,6 +81,7 @@ def test_rfsearchhook(): ...@@ -80,6 +81,7 @@ def test_rfsearchhook():
by_epoch=True, by_epoch=True,
verbose=True) verbose=True)
rfsearchhook_fixed_single_branch.config['structure'] = { rfsearchhook_fixed_single_branch.config['structure'] = {
'module.layer0.0': [1, 1],
'module.conv2': [2, 2], 'module.conv2': [2, 2],
'module.conv3': [1, 1] 'module.conv3': [1, 1]
} }
...@@ -90,14 +92,22 @@ def test_rfsearchhook(): ...@@ -90,14 +92,22 @@ def test_rfsearchhook():
by_epoch=True, by_epoch=True,
verbose=True) verbose=True)
rfsearchhook_fixed_multi_branch.config['structure'] = { rfsearchhook_fixed_multi_branch.config['structure'] = {
'module.layer0.0': [1, 1],
'module.conv2': [2, 2], 'module.conv2': [2, 2],
'module.conv3': [1, 1] 'module.conv3': [1, 1]
} }
def test_skip_layer():
assert not isinstance(model.stem, Conv2dRFSearchOp)
assert not isinstance(model.conv0, Conv2dRFSearchOp)
assert isinstance(model.layer0[0], Conv2dRFSearchOp)
assert not isinstance(model.layer0[1], Conv2dRFSearchOp)
# 1. test init_model() with mode of search # 1. test init_model() with mode of search
model = Model() model = Model()
rfsearchhook_search.init_model(model) rfsearchhook_search.init_model(model)
test_skip_layer()
assert not isinstance(model.conv1, Conv2dRFSearchOp) assert not isinstance(model.conv1, Conv2dRFSearchOp)
assert isinstance(model.conv2, Conv2dRFSearchOp) assert isinstance(model.conv2, Conv2dRFSearchOp)
assert isinstance(model.conv3, Conv2dRFSearchOp) assert isinstance(model.conv3, Conv2dRFSearchOp)
...@@ -111,6 +121,7 @@ def test_rfsearchhook(): ...@@ -111,6 +121,7 @@ def test_rfsearchhook():
runner.register_hook(rfsearchhook_search) runner.register_hook(rfsearchhook_search)
runner.run([loader], [('train', 1)]) runner.run([loader], [('train', 1)])
test_skip_layer()
assert not isinstance(model.conv1, Conv2dRFSearchOp) assert not isinstance(model.conv1, Conv2dRFSearchOp)
assert isinstance(model.conv2, Conv2dRFSearchOp) assert isinstance(model.conv2, Conv2dRFSearchOp)
assert isinstance(model.conv3, Conv2dRFSearchOp) assert isinstance(model.conv3, Conv2dRFSearchOp)
...@@ -145,6 +156,7 @@ def test_rfsearchhook(): ...@@ -145,6 +156,7 @@ def test_rfsearchhook():
model = Model() model = Model()
rfsearchhook_fixed_multi_branch.init_model(model) rfsearchhook_fixed_multi_branch.init_model(model)
test_skip_layer()
assert not isinstance(model.conv1, Conv2dRFSearchOp) assert not isinstance(model.conv1, Conv2dRFSearchOp)
assert isinstance(model.conv2, Conv2dRFSearchOp) assert isinstance(model.conv2, Conv2dRFSearchOp)
assert isinstance(model.conv3, Conv2dRFSearchOp) assert isinstance(model.conv3, Conv2dRFSearchOp)
...@@ -157,6 +169,7 @@ def test_rfsearchhook(): ...@@ -157,6 +169,7 @@ def test_rfsearchhook():
runner.register_hook(rfsearchhook_fixed_multi_branch) runner.register_hook(rfsearchhook_fixed_multi_branch)
runner.run([loader], [('train', 1)]) runner.run([loader], [('train', 1)])
test_skip_layer()
assert not isinstance(model.conv1, Conv2dRFSearchOp) assert not isinstance(model.conv1, Conv2dRFSearchOp)
assert isinstance(model.conv2, Conv2dRFSearchOp) assert isinstance(model.conv2, Conv2dRFSearchOp)
assert isinstance(model.conv3, Conv2dRFSearchOp) assert isinstance(model.conv3, Conv2dRFSearchOp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment