Commit c4588297 authored by Mashiro, committed by Zaida Zhou

Refine rfsearch and fix a typo

parent 1f9e5b57
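The core refinement replaces the module-level `MMLogger.get_current_instance()` logger with mmengine's `print_log` helper. Passing `'current'` as the logger argument resolves the current `MMLogger` instance at call time rather than at import time. A minimal sketch of the before/after pattern (the log message is illustrative):

```python
from mmengine.logging import MMLogger, print_log

# Before: the logger is captured once when the module is imported,
# which may happen before the runner has created its experiment logger.
logger = MMLogger.get_current_instance()
logger.info('Expand as [(1, 1), (3, 3)]')

# After: print_log looks up the logger when the message is emitted;
# the string 'current' means "use the current MMLogger instance".
print_log('Expand as [(1, 1), (3, 3)]', 'current')
```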
@@ -40,6 +40,7 @@ Module
 NonLocal3d
 Scale
 Swish
+Conv2dRFSearchOp

 Build Function
 ----------------
@@ -40,6 +40,7 @@ Module
 NonLocal3d
 Scale
 Swish
+Conv2dRFSearchOp

 Build Function
 ----------------
@@ -4,14 +4,12 @@ import copy
 import numpy as np
 import torch
 import torch.nn as nn
-from mmengine.logging import MMLogger
+from mmengine.logging import print_log
 from mmengine.model import BaseModule
 from torch import Tensor

 from .utils import expand_rates, get_single_padding

-logger = MMLogger.get_current_instance()
-

 class BaseConvRFSearchOp(BaseModule):
     """Based class of ConvRFSearchOp.

@@ -84,7 +82,7 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
         self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches))
         if self.verbose:
-            logger.info(f'Expand as {self.dilation_rates}')
+            print_log(f'Expand as {self.dilation_rates}', 'current')
         nn.init.constant_(self.branch_weights, global_config['init_alphas'])

     def forward(self, input: Tensor) -> Tensor:

@@ -118,13 +116,14 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
             output += outputs[i]
         return output

-    def estimate_rates(self):
+    def estimate_rates(self) -> None:
         """Estimate new dilation rate based on trained branch_weights."""
         norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
         if self.verbose:
-            logger.info('Estimate dilation {} with weight {}.'.format(
-                self.dilation_rates,
-                norm_w.detach().cpu().numpy().tolist()))
+            print_log(
+                'Estimate dilation {} with weight {}.'.format(
+                    self.dilation_rates,
+                    norm_w.detach().cpu().numpy().tolist()), 'current')

         sum0, sum1, w_sum = 0, 0, 0
         for i in range(len(self.dilation_rates)):

@@ -143,9 +142,9 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
             self.op_layer.padding = self.get_padding(self.op_layer.dilation)
             self.dilation_rates = [tuple(estimated)]
             if self.verbose:
-                logger.info(f'Estimate as {tuple(estimated)}')
+                print_log(f'Estimate as {tuple(estimated)}', 'current')

-    def expand_rates(self):
+    def expand_rates(self) -> None:
         """Expand dilation rate."""
         dilation = self.op_layer.dilation
         dilation_rates = expand_rates(dilation, self.global_config)

@@ -158,11 +157,11 @@ class Conv2dRFSearchOp(BaseConvRFSearchOp):
         self.dilation_rates = copy.deepcopy(dilation_rates)
         if self.verbose:
-            logger.info(f'Expand as {self.dilation_rates}')
+            print_log(f'Expand as {self.dilation_rates}', 'current')
         nn.init.constant_(self.branch_weights,
                           self.global_config['init_alphas'])

-    def get_padding(self, dilation):
+    def get_padding(self, dilation) -> tuple:
         padding = (get_single_padding(self.op_layer.kernel_size[0],
                                       self.op_layer.stride[0], dilation[0]),
                    get_single_padding(self.op_layer.kernel_size[1],
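For context, the wrapped operator can also be used on its own. The sketch below wraps a plain `nn.Conv2d` the way `RFSearchHook.wrap_model` does. Only `init_alphas` and `num_branches` are visible in this diff; the remaining search-config keys are taken from the RF-Next example configs and should be treated as assumptions.

```python
import torch
import torch.nn as nn

from mmcv.cnn.rfsearch import Conv2dRFSearchOp

# Illustrative search config; key names beyond 'init_alphas' and
# 'num_branches' are assumptions based on the RF-Next examples.
search_cfg = dict(
    step=0,
    max_step=12,
    search_interval=1,
    exp_rate=0.5,
    init_alphas=0.01,
    mmin=1,
    mmax=24,
    num_branches=2,
    skip_layer=[])

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
# Wrap the plain convolution with the searchable multi-branch op,
# mirroring Conv2dRFSearchOp(module, config['search'], verbose) in the hook.
wrapped = Conv2dRFSearchOp(conv, search_cfg, True)

x = torch.rand(1, 3, 16, 16)
out = wrapped(x)          # weighted sum over the dilation branches
print(out.shape)          # expected: torch.Size([1, 8, 16, 16])

wrapped.estimate_rates()  # collapse branches to one estimated dilation
wrapped.expand_rates()    # expand around it for the next search round
```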
@@ -3,15 +3,14 @@ import os
 from typing import Dict, Optional

 import mmengine
+import torch  # noqa
 import torch.nn as nn
 from mmengine.hooks import Hook
-from mmengine.logging import MMLogger
+from mmengine.logging import print_log
 from mmengine.registry import HOOKS

-from mmcv.cnn.rfsearch.utils import get_single_padding, write_to_json
-from .operator import BaseConvRFSearchOp
-
-logger = MMLogger.get_current_instance()
+from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp  # noqa
+from .utils import get_single_padding, write_to_json


 @HOOKS.register_module()

@@ -82,7 +81,7 @@ class RFSearchHook(Hook):
                 search/fixed_single_branch/fixed_multi_branch
         """
         if self.verbose:
-            logger.info('RFSearch init begin.')
+            print_log('RFSearch init begin.', 'current')
         if self.mode == 'search':
             if self.config['structure']:
                 self.set_model(model, search_op='Conv2d')

@@ -95,19 +94,19 @@ class RFSearchHook(Hook):
            else:
                raise NotImplementedError
        if self.verbose:
-            logger.info('RFSearch init end.')
+            print_log('RFSearch init end.', 'current')

    def after_train_epoch(self, runner):
        """Performs a dilation searching step after one training epoch."""
        if self.by_epoch and self.mode == 'search':
            self.step(runner.model, runner.work_dir)

-    def after_train_iter(self, runner):
+    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        """Performs a dilation searching step after one training iteration."""
        if not self.by_epoch and self.mode == 'search':
            self.step(runner.model, runner.work_dir)

-    def step(self, model: nn.Module, work_dir: str):
+    def step(self, model: nn.Module, work_dir: str) -> None:
        """Performs a dilation searching step.

        Args:

@@ -132,7 +131,7 @@ class RFSearchHook(Hook):
            ),
        )

-    def estimate_and_expand(self, model: nn.Module):
+    def estimate_and_expand(self, model: nn.Module) -> None:
        """estimate and search for RFConvOp.

        Args:

@@ -146,7 +145,7 @@ class RFSearchHook(Hook):
    def wrap_model(self,
                   model: nn.Module,
                   search_op: str = 'Conv2d',
-                   prefix: str = ''):
+                   prefix: str = '') -> None:
        """wrap model to support searchable conv op.

        Args:

@@ -176,8 +175,9 @@ class RFSearchHook(Hook):
                    module, self.config['search'], self.verbose)
                moduleWrap = moduleWrap.to(module.weight.device)
                if self.verbose:
-                    logger.info('Wrap model %s to %s.' %
-                                (str(module), str(moduleWrap)))
+                    print_log(
+                        'Wrap model %s to %s.' %
+                        (str(module), str(moduleWrap)), 'current')
                setattr(model, name, moduleWrap)
            elif not isinstance(module, BaseConvRFSearchOp):
                self.wrap_model(module, search_op, fullname)

@@ -186,7 +186,7 @@ class RFSearchHook(Hook):
                  model: nn.Module,
                  search_op: str = 'Conv2d',
                  init_rates: Optional[int] = None,
-                  prefix: str = ''):
+                  prefix: str = '') -> None:
        """set model based on config.

        Args:

@@ -231,8 +231,9 @@ class RFSearchHook(Hook):
                        self.config['structure'][fullname][1]))
                setattr(model, name, module)
                if self.verbose:
-                    logger.info(
-                        'Set module %s dilation as: [%d %d]' %
-                        (fullname, module.dilation[0], module.dilation[1]))
+                    print_log(
+                        'Set module %s dilation as: [%d %d]' %
+                        (fullname, module.dilation[0], module.dilation[1]),
+                        'current')
            elif not isinstance(module, BaseConvRFSearchOp):
                self.set_model(module, search_op, init_rates, fullname)
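Because the hook is registered in the mmengine `HOOKS` registry and `after_train_iter` now follows the mmengine `Hook` signature, it can be enabled from a training config. A hedged sketch; the constructor arguments other than `mode`, `config`, `verbose`, and `by_epoch`, as well as the search keys, mirror the RF-Next examples and are assumptions beyond what this diff shows:

```python
# Hypothetical snippet of an OpenMMLab training config enabling the hook.
custom_hooks = [
    dict(
        type='RFSearchHook',
        mode='search',            # or 'fixed_single_branch' / 'fixed_multi_branch'
        rfstructure_file=None,    # assumed argument for loading a searched structure
        verbose=True,
        by_epoch=True,            # step after each epoch instead of each iteration
        config=dict(
            search=dict(
                step=0,
                max_step=12,
                search_interval=1,
                exp_rate=0.5,
                init_alphas=0.01,
                mmin=1,
                mmax=24,
                num_branches=2,
                skip_layer=[])),
    )
]
```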
@@ -440,7 +440,7 @@ def imcrop(
         img (ndarray): Image to be cropped.
         bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
         scale (float, optional): Scale ratio of bboxes, the default value
-            1.0 means no padding.
+            1.0 means no scaling.
         pad_fill (Number | list[Number]): Value to be filled for padding.
             Default: None, which means no padding.
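The typo fix makes the `imcrop` docstring match the behavior: `scale` enlarges the boxes before cropping, while padding is controlled separately by `pad_fill`. A small usage sketch with illustrative sizes:

```python
import numpy as np
import mmcv

img = np.random.randint(0, 255, (100, 120, 3), dtype=np.uint8)
bbox = np.array([5, 5, 44, 44])  # (x1, y1, x2, y2), a 40x40 box near the corner

crop = mmcv.imcrop(img, bbox)               # plain 40x40 crop, no scaling
bigger = mmcv.imcrop(img, bbox, scale=1.5)  # box enlarged ~1.5x, clipped to the image
# With pad_fill set, the part of the scaled box outside the image is filled with 0.
padded = mmcv.imcrop(img, bbox, scale=2.0, pad_fill=0)
```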
 # Copyright (c) OpenMMLab. All rights reserved.
-"""Tests the rfsearch with runners.
-
-CommandLine:
-    pytest tests/test_runner/test_hooks.py
-    xdoctest tests/test_hooks.py zero
-"""
-import torch
 import torch.nn as nn
-from torch.utils.data import DataLoader

 from mmcv.cnn.rfsearch import Conv2dRFSearchOp, RFSearchHook
-from tests.test_runner.test_hooks import _build_demo_runner


 def test_rfsearchhook():

@@ -114,20 +105,6 @@ def test_rfsearchhook():
     assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
     assert model.conv3.dilation_rates == [(1, 1), (1, 2)]

-    # 1. test step() with mode of search
-    loader = DataLoader(torch.ones((1, 1, 1, 1)))
-    runner = _build_demo_runner()
-    runner.model = model
-    runner.register_hook(rfsearchhook_search)
-    runner.run([loader], [('train', 1)])
-
-    test_skip_layer()
-    assert not isinstance(model.conv1, Conv2dRFSearchOp)
-    assert isinstance(model.conv2, Conv2dRFSearchOp)
-    assert isinstance(model.conv3, Conv2dRFSearchOp)
-    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
-    assert model.conv3.dilation_rates == [(1, 1), (1, 3)]
-
     # 2. test init_model() with mode of fixed_single_branch
     model = Model()
     rfsearchhook_fixed_single_branch.init_model(model)

@@ -139,19 +116,6 @@ def test_rfsearchhook():
     assert model.conv2.dilation == (2, 2)
     assert model.conv3.dilation == (1, 1)

-    # 2. test step() with mode of fixed_single_branch
-    runner = _build_demo_runner()
-    runner.model = model
-    runner.register_hook(rfsearchhook_fixed_single_branch)
-    runner.run([loader], [('train', 1)])
-
-    assert not isinstance(model.conv1, Conv2dRFSearchOp)
-    assert not isinstance(model.conv2, Conv2dRFSearchOp)
-    assert not isinstance(model.conv3, Conv2dRFSearchOp)
-    assert model.conv1.dilation == (1, 1)
-    assert model.conv2.dilation == (2, 2)
-    assert model.conv3.dilation == (1, 1)
-
     # 3. test init_model() with mode of fixed_multi_branch
     model = Model()
     rfsearchhook_fixed_multi_branch.init_model(model)

@@ -162,16 +126,3 @@ def test_rfsearchhook():
     assert isinstance(model.conv3, Conv2dRFSearchOp)
     assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
     assert model.conv3.dilation_rates == [(1, 1), (1, 2)]
-
-    # 3. test step() with mode of fixed_single_branch
-    runner = _build_demo_runner()
-    runner.model = model
-    runner.register_hook(rfsearchhook_fixed_multi_branch)
-    runner.run([loader], [('train', 1)])
-
-    test_skip_layer()
-    assert not isinstance(model.conv1, Conv2dRFSearchOp)
-    assert isinstance(model.conv2, Conv2dRFSearchOp)
-    assert isinstance(model.conv3, Conv2dRFSearchOp)
-    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
-    assert model.conv3.dilation_rates == [(1, 1), (1, 2)]