Unverified Commit 6e3827d0 authored by Shanghua Gao's avatar Shanghua Gao Committed by GitHub
Browse files

[Feature] Support receptive field search of CNN models (#2056)



* support rfsearch

* add labs for rfsearch

* format

* format

* add docstring and type hints

* clean code
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* rm unused func

* update code

* update code

* update code

* update details

* fix details

* support asymmetric kernel

* support asymmetric kernel

* Apply suggestions from code review
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Apply suggestions from code review

* Apply suggestions from code review
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Apply suggestions from code review

* add unit tests for rfsearch

* set device for Conv2dRFSearchOp

* Apply suggestions from code review
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* remove unused function search_estimate_only

* move unit tests

* Update tests/test_cnn/test_rfsearch/test_operator.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/cnn/rfsearch/operator.py
Co-authored-by: default avatarYue Zhou <592267829@qq.com>

* change logger

* Update mmcv/cnn/rfsearch/operator.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: default avatarlzyhha <819814373@qq.com>
Co-authored-by: default avatarZhongyu Li <44114862+lzyhha@users.noreply.github.com>
Co-authored-by: default avatarYue Zhou <592267829@qq.com>
parent abb600ad
......@@ -14,6 +14,7 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
from .builder import MODELS, build_model_from_cfg
# yapf: enable
from .resnet import ResNet, make_res_layer
from .rfsearch import Conv2dRFSearchOp, RFSearchHook
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
XavierInit, bias_init_with_prob, caffe2_xavier_init,
......@@ -37,5 +38,6 @@ __all__ = [
'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg', 'Conv2dRFSearchOp',
'RFSearchHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp
from .search import RFSearchHook
__all__ = ['BaseConvRFSearchOp', 'Conv2dRFSearchOp', 'RFSearchHook']
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from mmcv.runner import BaseModule
from mmcv.utils.logging import get_logger
from .utils import expand_rates, get_single_padding
logger = get_logger('mmcv')
class BaseConvRFSearchOp(BaseModule):
    """Base class of receptive-field-searchable conv operators.

    Args:
        op_layer (nn.Module): The wrapped pytorch module, e.g. ``Conv2d``.
        global_config (dict): The search config dict.
    """

    def __init__(self, op_layer: nn.Module, global_config: dict):
        super().__init__()
        self.op_layer = op_layer
        self.global_config = global_config

    def normlize(self, weights: nn.Parameter) -> nn.Parameter:
        """Normalize weights so their absolute values sum to one.

        Args:
            weights (nn.Parameter): Weights to be normalized.

        Returns:
            nn.Parameters: Normalized weights.
        """
        magnitudes = weights.abs()
        return magnitudes / magnitudes.sum()
class Conv2dRFSearchOp(BaseConvRFSearchOp):
    """Enable Conv2d with receptive field searching ability.

    Args:
        op_layer (nn.Module): pytorch module, e,g, Conv2d
        global_config (dict): config dict. Defaults to None.
            By default this must include:

            - "init_alphas": The value for initializing weights of each branch.
            - "num_branches": The controller of the size of
              search space (the number of branches).
            - "exp_rate": The controller of the sparsity of search space.
            - "mmin": The minimum dilation rate.
            - "mmax": The maximum dilation rate.

            Extra keys may exist, but are used by RFSearchHook, e.g., "step",
            "max_step", "search_interval", and "skip_layer".
        verbose (bool): Determines whether to print rf-next
            related logging messages.
            Defaults to True.
    """

    def __init__(self,
                 op_layer: nn.Module,
                 global_config: dict,
                 verbose: bool = True):
        super().__init__(op_layer, global_config)
        assert global_config is not None, 'global_config is None'
        self.num_branches = global_config['num_branches']
        assert self.num_branches in [2, 3]
        self.verbose = verbose
        init_dilation = op_layer.dilation
        # Candidate dilation rates expanded around the current dilation.
        self.dilation_rates = expand_rates(init_dilation, global_config)
        # Axes with kernel size 1 or an even kernel size are not searchable:
        # pin the original dilation on that axis for every candidate.
        if self.op_layer.kernel_size[
                0] == 1 or self.op_layer.kernel_size[0] % 2 == 0:
            self.dilation_rates = [(op_layer.dilation[0], r[1])
                                   for r in self.dilation_rates]
        if self.op_layer.kernel_size[
                1] == 1 or self.op_layer.kernel_size[1] % 2 == 0:
            self.dilation_rates = [(r[0], op_layer.dilation[1])
                                   for r in self.dilation_rates]

        # One learnable weight (alpha) per branch; all initialized equal so
        # every candidate starts on equal footing.
        self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches))
        if self.verbose:
            logger.info(f'Expand as {self.dilation_rates}')
        nn.init.constant_(self.branch_weights, global_config['init_alphas'])

    def forward(self, input: Tensor) -> Tensor:
        """Run every dilation branch and sum the weighted branch outputs."""
        norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
        if len(self.dilation_rates) == 1:
            # Single branch: plain convolution with the estimated dilation;
            # the branch weight is not applied.
            outputs = [
                nn.functional.conv2d(
                    input,
                    weight=self.op_layer.weight,
                    bias=self.op_layer.bias,
                    stride=self.op_layer.stride,
                    padding=self.get_padding(self.dilation_rates[0]),
                    dilation=self.dilation_rates[0],
                    groups=self.op_layer.groups,
                )
            ]
        else:
            # Multiple branches share the same conv weights but use
            # different dilations; each output is scaled by its normalized
            # branch weight. Padding is chosen per-branch to keep the
            # spatial size identical across branches.
            outputs = [
                nn.functional.conv2d(
                    input,
                    weight=self.op_layer.weight,
                    bias=self.op_layer.bias,
                    stride=self.op_layer.stride,
                    padding=self.get_padding(r),
                    dilation=r,
                    groups=self.op_layer.groups,
                ) * norm_w[i] for i, r in enumerate(self.dilation_rates)
            ]
        output = outputs[0]
        for i in range(1, len(self.dilation_rates)):
            output += outputs[i]
        return output

    def estimate_rates(self) -> None:
        """Estimate new dilation rate based on trained branch_weights."""
        norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
        if self.verbose:
            logger.info('Estimate dilation {} with weight {}.'.format(
                self.dilation_rates,
                norm_w.detach().cpu().numpy().tolist()))

        # Weighted average of the candidate dilations (per axis), rounded
        # to the nearest integer and clamped into [mmin, mmax].
        sum0, sum1, w_sum = 0, 0, 0
        for i in range(len(self.dilation_rates)):
            sum0 += norm_w[i].item() * self.dilation_rates[i][0]
            sum1 += norm_w[i].item() * self.dilation_rates[i][1]
            w_sum += norm_w[i].item()
        estimated = [
            np.clip(
                int(round(sum0 / w_sum)), self.global_config['mmin'],
                self.global_config['mmax']).item(),
            np.clip(
                int(round(sum1 / w_sum)), self.global_config['mmin'],
                self.global_config['mmax']).item()
        ]
        self.op_layer.dilation = tuple(estimated)
        self.op_layer.padding = self.get_padding(self.op_layer.dilation)
        # Collapse back to a single branch until the next expand step.
        self.dilation_rates = [tuple(estimated)]
        if self.verbose:
            logger.info(f'Estimate as {tuple(estimated)}')

    def expand_rates(self) -> None:
        """Expand dilation rate."""
        dilation = self.op_layer.dilation
        dilation_rates = expand_rates(dilation, self.global_config)
        # As in __init__: non-searchable axes keep their current dilation.
        if self.op_layer.kernel_size[
                0] == 1 or self.op_layer.kernel_size[0] % 2 == 0:
            dilation_rates = [(dilation[0], r[1]) for r in dilation_rates]
        if self.op_layer.kernel_size[
                1] == 1 or self.op_layer.kernel_size[1] % 2 == 0:
            dilation_rates = [(r[0], dilation[1]) for r in dilation_rates]

        self.dilation_rates = copy.deepcopy(dilation_rates)
        if self.verbose:
            logger.info(f'Expand as {self.dilation_rates}')
        # Reset branch weights so the new candidates restart from equal
        # initial importance.
        nn.init.constant_(self.branch_weights,
                          self.global_config['init_alphas'])

    def get_padding(self, dilation) -> tuple:
        """Return the (h, w) padding preserving output size for ``dilation``."""
        padding = (get_single_padding(self.op_layer.kernel_size[0],
                                      self.op_layer.stride[0], dilation[0]),
                   get_single_padding(self.op_layer.kernel_size[1],
                                      self.op_layer.stride[1], dilation[1]))
        return padding
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict, Optional
import torch # noqa
import torch.nn as nn
import mmcv
from mmcv.cnn.rfsearch.utils import get_single_padding, write_to_json
from mmcv.runner import HOOKS, Hook
from mmcv.utils import get_logger
from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp # noqa
logger = get_logger('mmcv')
@HOOKS.register_module()
class RFSearchHook(Hook):
    """Receptive field search via dilation rates.

    Please refer to `RF-Next: Efficient Receptive Field
    Search for Convolutional Neural Networks
    <https://arxiv.org/abs/2206.06637>`_ for more details.

    Args:
        mode (str, optional): It can be set to the following types:
            'search', 'fixed_single_branch', or 'fixed_multi_branch'.
            Defaults to 'search'.
        config (Dict, optional): config dict of search.
            By default this config contains "search",
            and config["search"] must include:

            - "step": recording the current searching step.
            - "max_step": The maximum number of searching steps
              to update the structures.
            - "search_interval": The interval (epoch/iteration)
              between two updates.
            - "exp_rate": The controller of the sparsity of search space.
            - "init_alphas": The value for initializing weights of each branch.
            - "mmin": The minimum dilation rate.
            - "mmax": The maximum dilation rate.
            - "num_branches": The controller of the size of
              search space (the number of branches).
            - "skip_layer": The modules in skip_layer will be ignored
              during the receptive field search.
        rfstructure_file (str, optional): Path to load searched receptive
            fields of the model. Defaults to None.
        by_epoch (bool, optional): Determine to perform step by epoch or
            by iteration. If set to True, it will step by epoch. Otherwise, by
            iteration. Defaults to True.
        verbose (bool): Determines whether to print rf-next related logging
            messages. Defaults to True.
    """

    # Mapping from a searchable torch.nn module name to its RF-search
    # wrapper. Only Conv2d is supported for now.
    _wrappers = {'Conv2d': Conv2dRFSearchOp}

    def __init__(self,
                 mode: str = 'search',
                 config: Optional[Dict] = None,
                 rfstructure_file: Optional[str] = None,
                 by_epoch: bool = True,
                 verbose: bool = True):
        assert mode in ['search', 'fixed_single_branch', 'fixed_multi_branch']
        # Use None as the default instead of a mutable ``{}``: the previous
        # signature mutated the shared default dict (``config['structure']``
        # below), leaking state across hook instances.
        self.config = config if config is not None else {}
        self.config['structure'] = {}
        self.verbose = verbose
        if rfstructure_file is not None:
            # Load previously searched receptive fields.
            rfstructure = mmcv.load(rfstructure_file)['structure']
            self.config['structure'] = rfstructure
        self.mode = mode
        self.num_branches = self.config['search']['num_branches']
        self.by_epoch = by_epoch

    def init_model(self, model: nn.Module):
        """init model with search ability.

        Args:
            model (nn.Module): pytorch model

        Raises:
            NotImplementedError: only support three modes:
                search/fixed_single_branch/fixed_multi_branch
        """
        if self.verbose:
            logger.info('RFSearch init begin.')
        if self.mode == 'search':
            # Restore a previously searched structure (if any), then wrap
            # convs with multi-branch search operators.
            if self.config['structure']:
                self.set_model(model, search_op='Conv2d')
            self.wrap_model(model, search_op='Conv2d')
        elif self.mode == 'fixed_single_branch':
            # Only apply the searched dilations; no searchable branches.
            self.set_model(model, search_op='Conv2d')
        elif self.mode == 'fixed_multi_branch':
            # Apply searched dilations and keep the multi-branch operators,
            # but never step the search.
            self.set_model(model, search_op='Conv2d')
            self.wrap_model(model, search_op='Conv2d')
        else:
            raise NotImplementedError
        if self.verbose:
            logger.info('RFSearch init end.')

    def after_train_epoch(self, runner):
        """Performs a dilation searching step after one training epoch."""
        if self.by_epoch and self.mode == 'search':
            self.step(runner.model, runner.work_dir)

    def after_train_iter(self, runner):
        """Performs a dilation searching step after one training iteration."""
        if not self.by_epoch and self.mode == 'search':
            self.step(runner.model, runner.work_dir)

    def step(self, model: nn.Module, work_dir: str):
        """Performs a dilation searching step.

        Args:
            model (nn.Module): pytorch model
            work_dir (str): Directory to save the searching results.
        """
        search_cfg = self.config['search']
        search_cfg['step'] += 1
        step = search_cfg['step']
        # Only estimate/expand every ``search_interval`` steps and while the
        # search budget ``max_step`` is not exhausted.
        if step % search_cfg['search_interval'] == 0 and \
                step < search_cfg['max_step']:
            self.estimate_and_expand(model)
            # Record the newly estimated dilation of every wrapped conv.
            for name, module in model.named_modules():
                if isinstance(module, BaseConvRFSearchOp):
                    self.config['structure'][name] = module.op_layer.dilation

            write_to_json(
                self.config,
                os.path.join(
                    work_dir,
                    'local_search_config_step%d.json' % step,
                ),
            )

    def estimate_and_expand(self, model: nn.Module):
        """estimate and search for RFConvOp.

        Args:
            model (nn.Module): pytorch model
        """
        for module in model.modules():
            if isinstance(module, BaseConvRFSearchOp):
                module.estimate_rates()
                module.expand_rates()

    def wrap_model(self, model: nn.Module, search_op: str = 'Conv2d'):
        """wrap model to support searchable conv op.

        Args:
            model (nn.Module): pytorch model
            search_op (str): The module that uses RF search.
                Defaults to 'Conv2d'.
        """
        # Resolve the target class by attribute lookup instead of ``eval``
        # so only real ``torch.nn`` modules can be matched.
        search_cls = getattr(nn, search_op)
        for name, module in model.named_children():
            if isinstance(module, search_cls):
                # Only wrap convs with an odd kernel larger than 1 on at
                # least one axis; other kernels are not searchable.
                if 1 < module.kernel_size[0] and \
                        0 != module.kernel_size[0] % 2 or \
                        1 < module.kernel_size[1] and \
                        0 != module.kernel_size[1] % 2:
                    wrapped = self._wrappers[search_op](
                        module, self.config['search'], self.verbose)
                    wrapped = wrapped.to(module.weight.device)
                    if self.verbose:
                        logger.info('Wrap model %s to %s.' %
                                    (str(module), str(wrapped)))
                    setattr(model, name, wrapped)
            elif isinstance(module, BaseConvRFSearchOp):
                # Already wrapped; nothing to do.
                pass
            else:
                # Recurse into submodules unless the name matches a
                # skip_layer pattern.
                if self.config['search']['skip_layer'] is not None:
                    if any(layer in name
                           for layer in self.config['search']['skip_layer']):
                        continue
                self.wrap_model(module, search_op)

    def set_model(self,
                  model: nn.Module,
                  search_op: str = 'Conv2d',
                  init_rates: Optional[int] = None,
                  prefix: str = ''):
        """set model based on config.

        Args:
            model (nn.Module): pytorch model
            search_op (str): The module that uses RF search.
                Defaults to 'Conv2d'.
            init_rates (int, optional): Set to other initial dilation rates.
                Defaults to None.
            prefix (str): Prefix for function recursion. Defaults to ''.
        """
        search_cls = getattr(nn, search_op)
        for name, module in model.named_children():
            # Structure keys are recorded as they appear on a wrapped
            # (DataParallel/DDP) model, hence the leading 'module.'.
            if prefix == '':
                fullname = 'module.' + name
            else:
                fullname = prefix + '.' + name
            if isinstance(module, search_cls):
                if 1 < module.kernel_size[0] and \
                        0 != module.kernel_size[0] % 2 or \
                        1 < module.kernel_size[1] and \
                        0 != module.kernel_size[1] % 2:
                    structure = self.config['structure']
                    # Accept a scalar dilation in the structure file and
                    # normalize it to a (h, w) pair.
                    if isinstance(structure[fullname], int):
                        structure[fullname] = [
                            structure[fullname], structure[fullname]
                        ]
                    module.dilation = (
                        structure[fullname][0],
                        structure[fullname][1],
                    )
                    # Recompute padding so the output size stays unchanged
                    # with the new dilation.
                    module.padding = (
                        get_single_padding(module.kernel_size[0],
                                           module.stride[0],
                                           structure[fullname][0]),
                        get_single_padding(module.kernel_size[1],
                                           module.stride[1],
                                           structure[fullname][1]))
                    setattr(model, name, module)
                    if self.verbose:
                        logger.info(
                            'Set module %s dilation as: [%d %d]' %
                            (fullname, module.dilation[0], module.dilation[1]))
            elif isinstance(module, BaseConvRFSearchOp):
                pass
            else:
                if self.config['search']['skip_layer'] is not None:
                    if any(layer in fullname
                           for layer in self.config['search']['skip_layer']):
                        continue
                self.set_model(module, search_op, init_rates, fullname)
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import mmcv
def write_to_json(config: dict, filename: str):
    """Dump a search config dict into a json file.

    Args:
        config (dict): Config to be saved.
        filename (str): Path to save config.
    """
    with open(filename, 'w', encoding='utf-8') as json_file:
        mmcv.dump(config, json_file, file_format='json')
def expand_rates(dilation: tuple, config: dict) -> list:
    """Expand candidate dilation rates around ``dilation`` according to config.

    Args:
        dilation (tuple): Current dilation rate, one int per spatial axis.
        config (dict): Config dict containing ``exp_rate`` (the controller of
            the sparsity of the search space), ``num_branches`` (the size of
            the search space) and ``mmin``/``mmax`` (dilation bounds).

    Returns:
        list: List of expanded dilation-rate tuples, deduplicated while
        preserving first-occurrence order.
    """
    exp_rate = config['exp_rate']

    def _scaled(scale: float) -> tuple:
        # Scale every axis of the dilation, round to the nearest integer
        # and clamp into [mmin, mmax]. Shared by the large/small branches
        # to avoid duplicating the clip logic.
        return tuple(
            np.clip(int(round(scale * d)), config['mmin'],
                    config['mmax']).item() for d in dilation)

    large_rates = []
    small_rates = []
    for _ in range(config['num_branches'] // 2):
        large_rates.append(_scaled(1 + exp_rate))
        small_rates.append(_scaled(1 - exp_rate))
    small_rates.reverse()

    # With an odd number of branches, the current dilation stays in the
    # middle of the candidate list.
    if config['num_branches'] % 2 == 0:
        rate_list = small_rates + large_rates
    else:
        rate_list = small_rates + [dilation] + large_rates

    # Deduplicate while keeping the original order (clamping may have
    # collapsed several candidates onto the same rate).
    unique_rate_list = list(set(rate_list))
    unique_rate_list.sort(key=rate_list.index)
    return unique_rate_list
def get_single_padding(kernel_size: int,
                       stride: int = 1,
                       dilation: int = 1) -> int:
    """Compute the padding for one spatial axis.

    Args:
        kernel_size (int): Kernel size along this axis.
        stride (int): Stride along this axis. Defaults to 1.
        dilation (int): Dilation along this axis. Defaults to 1.

    Returns:
        int: Padding that preserves the 'same'-style output size for the
        given kernel size, stride and dilation.
    """
    effective_extent = dilation * (kernel_size - 1)
    return (stride - 1 + effective_extent) // 2
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import torch
import torch.nn as nn
from mmcv.cnn.rfsearch.operator import Conv2dRFSearchOp
# Shared search config used by all operator tests below; individual tests
# deepcopy and tweak it where needed (e.g. num_branches).
global_config = dict(
    step=0,
    max_step=12,
    search_interval=1,
    exp_rate=0.5,
    init_alphas=0.01,
    mmin=1,
    mmax=24,
    num_branches=2,
    skip_layer=['stem', 'layer1'])
# test with 3x3 conv
def test_rfsearch_operator_3x3():
    """Search cycle on a symmetric 3x3 conv: both axes are searchable."""
    conv = nn.Conv2d(
        in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
    operator = Conv2dRFSearchOp(conv, global_config)
    x = torch.randn(1, 3, 32, 32)

    # set no_grad to perform in-place operator
    with torch.no_grad():
        # After expand: (1, 1) (2, 2)
        assert len(operator.dilation_rates) == 2
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.dilation_rates[1] == (2, 2)
        assert torch.all(operator.branch_weights.data ==
                         global_config['init_alphas']).item()
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # After estimate: (2, 2) with branch_weights of [0.5 0.5]
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (2, 2)
        assert operator.op_layer.dilation == (2, 2)
        assert operator.op_layer.padding == (2, 2)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # After expand: (1, 1) (3, 3)
        operator.expand_rates()
        assert len(operator.dilation_rates) == 2
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.dilation_rates[1] == (3, 3)
        assert torch.all(operator.branch_weights.data ==
                         global_config['init_alphas']).item()
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # Bias the weights toward the second branch before estimating.
        operator.branch_weights[0] = 0.1
        operator.branch_weights[1] = 0.4
        # After estimate: (3, 3) with branch_weights of [0.2 0.8]
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (3, 3)
        assert operator.op_layer.dilation == (3, 3)
        assert operator.op_layer.padding == (3, 3)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)
# test with 5x5 conv
def test_rfsearch_operator_5x5():
    """Search cycle on a symmetric 5x5 conv: padding scales with kernel."""
    conv = nn.Conv2d(
        in_channels=3, out_channels=3, kernel_size=5, stride=1, padding=2)
    operator = Conv2dRFSearchOp(conv, global_config)
    x = torch.randn(1, 3, 32, 32)

    with torch.no_grad():
        # After expand: (1, 1) (2, 2)
        assert len(operator.dilation_rates) == 2
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.dilation_rates[1] == (2, 2)
        assert torch.all(operator.branch_weights.data ==
                         global_config['init_alphas']).item()
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # After estimate: (2, 2) with branch_weights of [0.5 0.5]
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (2, 2)
        assert operator.op_layer.dilation == (2, 2)
        # 5x5 kernel with dilation 2 -> padding 4 to keep the output size.
        assert operator.op_layer.padding == (4, 4)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # After expand: (1, 1) (3, 3)
        operator.expand_rates()
        assert len(operator.dilation_rates) == 2
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.dilation_rates[1] == (3, 3)
        assert torch.all(operator.branch_weights.data ==
                         global_config['init_alphas']).item()
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        operator.branch_weights[0] = 0.1
        operator.branch_weights[1] = 0.4
        # After estimate: (3, 3) with branch_weights of [0.2 0.8]
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (3, 3)
        assert operator.op_layer.dilation == (3, 3)
        assert operator.op_layer.padding == (6, 6)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)
# test with 5x5 conv num_branches=3
def test_rfsearch_operator_5x5_branch3():
    """Search cycle with num_branches=3: the current dilation is kept as
    the middle branch after expansion."""
    conv = nn.Conv2d(
        in_channels=3, out_channels=3, kernel_size=5, stride=1, padding=2)
    config = deepcopy(global_config)
    config['num_branches'] = 3
    operator = Conv2dRFSearchOp(conv, config)
    x = torch.randn(1, 3, 32, 32)

    with torch.no_grad():
        # After expand: (1, 1) (2, 2)
        # (the middle branch (1, 1) collapses with the small branch here)
        assert len(operator.dilation_rates) == 2
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.dilation_rates[1] == (2, 2)
        # config is a deepcopy, so init_alphas matches global_config.
        assert torch.all(operator.branch_weights.data ==
                         global_config['init_alphas']).item()
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # After estimate: (2, 2) with branch_weights of [0.5 0.5]
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (2, 2)
        assert operator.op_layer.dilation == (2, 2)
        assert operator.op_layer.padding == (4, 4)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # After expand: (1, 1) (2, 2) (3, 3)
        operator.expand_rates()
        assert len(operator.dilation_rates) == 3
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.dilation_rates[1] == (2, 2)
        assert operator.dilation_rates[2] == (3, 3)
        assert torch.all(operator.branch_weights.data ==
                         global_config['init_alphas']).item()
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        operator.branch_weights[0] = 0.1
        operator.branch_weights[1] = 0.3
        operator.branch_weights[2] = 0.6
        # After estimate: (3, 3) with branch_weights of [0.1 0.3 0.6]
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (3, 3)
        assert operator.op_layer.dilation == (3, 3)
        assert operator.op_layer.padding == (6, 6)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)
# test with 1x5 conv
def test_rfsearch_operator_1x5():
    """Search cycle on an asymmetric 1x5 conv: only the second (width)
    axis is searchable; the first axis keeps dilation 1."""
    conv = nn.Conv2d(
        in_channels=3,
        out_channels=3,
        kernel_size=(1, 5),
        stride=1,
        padding=(0, 2))
    operator = Conv2dRFSearchOp(conv, global_config)
    x = torch.randn(1, 3, 32, 32)

    # After expand: (1, 1) (1, 2)
    assert len(operator.dilation_rates) == 2
    assert operator.dilation_rates[0] == (1, 1)
    assert operator.dilation_rates[1] == (1, 2)
    assert torch.all(
        operator.branch_weights.data == global_config['init_alphas']).item()
    # test forward
    assert operator(x).shape == (1, 3, 32, 32)

    with torch.no_grad():
        # After estimate: (1, 2) with branch_weights of [0.5 0.5]
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (1, 2)
        assert operator.op_layer.dilation == (1, 2)
        assert operator.op_layer.padding == (0, 4)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # After expand: (1, 1) (1, 3)
        operator.expand_rates()
        assert len(operator.dilation_rates) == 2
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.dilation_rates[1] == (1, 3)
        assert torch.all(operator.branch_weights.data ==
                         global_config['init_alphas']).item()
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        operator.branch_weights[0] = 0.2
        operator.branch_weights[1] = 0.8
        # After estimate: (1, 3) with branch_weights of [0.2 0.8]
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (1, 3)
        assert operator.op_layer.dilation == (1, 3)
        assert operator.op_layer.padding == (0, 6)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)
# test with 5x5 conv initial_dilation=(2, 2)
def test_rfsearch_operator_5x5_d2x2():
    """Search cycle starting from a non-default dilation (2, 2)."""
    conv = nn.Conv2d(
        in_channels=3,
        out_channels=3,
        kernel_size=5,
        stride=1,
        padding=4,
        dilation=(2, 2))
    operator = Conv2dRFSearchOp(conv, global_config)
    x = torch.randn(1, 3, 32, 32)

    with torch.no_grad():
        # After expand: (1, 1) (3, 3)
        assert len(operator.dilation_rates) == 2
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.dilation_rates[1] == (3, 3)
        assert torch.all(operator.branch_weights.data ==
                         global_config['init_alphas']).item()
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # After estimate: (2, 2) with branch_weights of [0.5 0.5]
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (2, 2)
        assert operator.op_layer.dilation == (2, 2)
        assert operator.op_layer.padding == (4, 4)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # After expand: (1, 1) (3, 3)
        operator.expand_rates()
        assert len(operator.dilation_rates) == 2
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.dilation_rates[1] == (3, 3)
        assert torch.all(operator.branch_weights.data ==
                         global_config['init_alphas']).item()
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        operator.branch_weights[0] = 0.8
        operator.branch_weights[1] = 0.2
        # After estimate: (1, 1) with branch_weights of [0.8 0.2]
        # (0.8 * 1 + 0.2 * 3 = 1.4, which rounds to 1)
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.op_layer.dilation == (1, 1)
        assert operator.op_layer.padding == (2, 2)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)
# test with 5x5 conv initial_dilation=(1, 2)
def test_rfsearch_operator_5x5_d1x2():
    """Search cycle starting from an asymmetric dilation (1, 2)."""
    conv = nn.Conv2d(
        in_channels=3,
        out_channels=3,
        kernel_size=5,
        stride=1,
        padding=(2, 4),
        dilation=(1, 2))
    operator = Conv2dRFSearchOp(conv, global_config)
    x = torch.randn(1, 3, 32, 32)

    with torch.no_grad():
        # After expand: (1, 1) (2, 3)
        assert len(operator.dilation_rates) == 2
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.dilation_rates[1] == (2, 3)
        assert torch.all(operator.branch_weights.data ==
                         global_config['init_alphas']).item()
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # After estimate: (2, 2) with branch_weights of [0.5 0.5]
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (2, 2)
        assert operator.op_layer.dilation == (2, 2)
        assert operator.op_layer.padding == (4, 4)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        # After expand: (1, 1) (3, 3)
        operator.expand_rates()
        assert len(operator.dilation_rates) == 2
        assert operator.dilation_rates[0] == (1, 1)
        assert operator.dilation_rates[1] == (3, 3)
        assert torch.all(operator.branch_weights.data ==
                         global_config['init_alphas']).item()
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)

        operator.branch_weights[0] = 0.1
        operator.branch_weights[1] = 0.8
        # After estimate: (3, 3) with branch_weights of [0.1 0.8]
        operator.estimate_rates()
        assert len(operator.dilation_rates) == 1
        assert operator.dilation_rates[0] == (3, 3)
        assert operator.op_layer.dilation == (3, 3)
        assert operator.op_layer.padding == (6, 6)
        # test forward
        assert operator(x).shape == (1, 3, 32, 32)
# Copyright (c) OpenMMLab. All rights reserved.
"""Tests the rfsearch with runners.
CommandLine:
pytest tests/test_runner/test_hooks.py
xdoctest tests/test_hooks.py zero
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from mmcv.cnn.rfsearch import Conv2dRFSearchOp, RFSearchHook
from tests.test_runner.test_hooks import _build_demo_runner
def test_rfsearchhook():
    """Exercise RFSearchHook init_model() and step() in all three modes:
    search, fixed_single_branch, and fixed_multi_branch."""

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            # 1x1 conv: kernel too small, must never be wrapped.
            self.conv1 = nn.Conv2d(
                in_channels=1,
                out_channels=2,
                kernel_size=1,
                stride=1,
                padding=0,
                dilation=1)
            # 3x3 conv: searchable on both axes.
            self.conv2 = nn.Conv2d(
                in_channels=2,
                out_channels=2,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1)
            # 1x3 conv: searchable on the second (width) axis only.
            self.conv3 = nn.Conv2d(
                in_channels=1,
                out_channels=2,
                kernel_size=(1, 3),
                stride=1,
                padding=(0, 1),
                dilation=1)

        def forward(self, x):
            # NOTE: conv3 is deliberately excluded from the forward pass;
            # it only exercises wrapping/setting of asymmetric kernels.
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2

        def train_step(self, x, optimizer, **kwargs):
            # Minimal train_step API expected by the demo runner.
            return dict(loss=self(x).mean(), num_samples=x.shape[0])

    rfsearch_cfg = dict(
        mode='search',
        rfstructure_file=None,
        config=dict(
            search=dict(
                step=0,
                max_step=12,
                search_interval=1,
                exp_rate=0.5,
                init_alphas=0.01,
                mmin=1,
                mmax=24,
                num_branches=2,
                skip_layer=['stem', 'layer1'])),
    )

    # hook for search
    rfsearchhook_search = RFSearchHook(
        'search', rfsearch_cfg['config'], by_epoch=True, verbose=True)
    # Inject a pre-searched structure; keys use the 'module.' prefix of a
    # wrapped (DDP) model, as produced by set_model/step.
    rfsearchhook_search.config['structure'] = {
        'module.conv2': [2, 2],
        'module.conv3': [1, 1]
    }
    # hook for fixed_single_branch
    rfsearchhook_fixed_single_branch = RFSearchHook(
        'fixed_single_branch',
        rfsearch_cfg['config'],
        by_epoch=True,
        verbose=True)
    rfsearchhook_fixed_single_branch.config['structure'] = {
        'module.conv2': [2, 2],
        'module.conv3': [1, 1]
    }
    # hook for fixed_multi_branch
    rfsearchhook_fixed_multi_branch = RFSearchHook(
        'fixed_multi_branch',
        rfsearch_cfg['config'],
        by_epoch=True,
        verbose=True)
    rfsearchhook_fixed_multi_branch.config['structure'] = {
        'module.conv2': [2, 2],
        'module.conv3': [1, 1]
    }

    # 1. test init_model() with mode of search
    model = Model()
    rfsearchhook_search.init_model(model)
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert isinstance(model.conv2, Conv2dRFSearchOp)
    assert isinstance(model.conv3, Conv2dRFSearchOp)
    # conv2 restored to dilation (2, 2) then expanded to (1, 1)/(3, 3).
    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
    # conv3 is 1x3, so the first axis keeps dilation 1.
    assert model.conv3.dilation_rates == [(1, 1), (1, 2)]

    # 1. test step() with mode of search
    loader = DataLoader(torch.ones((1, 1, 1, 1)))
    runner = _build_demo_runner()
    runner.model = model
    runner.register_hook(rfsearchhook_search)
    runner.run([loader], [('train', 1)])
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert isinstance(model.conv2, Conv2dRFSearchOp)
    assert isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
    assert model.conv3.dilation_rates == [(1, 1), (1, 3)]

    # 2. test init_model() with mode of fixed_single_branch
    model = Model()
    rfsearchhook_fixed_single_branch.init_model(model)
    # No module is wrapped; only the dilations from the structure are set.
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert not isinstance(model.conv2, Conv2dRFSearchOp)
    assert not isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv1.dilation == (1, 1)
    assert model.conv2.dilation == (2, 2)
    assert model.conv3.dilation == (1, 1)

    # 2. test step() with mode of fixed_single_branch
    runner = _build_demo_runner()
    runner.model = model
    runner.register_hook(rfsearchhook_fixed_single_branch)
    runner.run([loader], [('train', 1)])
    # step() is a no-op outside of 'search' mode.
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert not isinstance(model.conv2, Conv2dRFSearchOp)
    assert not isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv1.dilation == (1, 1)
    assert model.conv2.dilation == (2, 2)
    assert model.conv3.dilation == (1, 1)

    # 3. test init_model() with mode of fixed_multi_branch
    model = Model()
    rfsearchhook_fixed_multi_branch.init_model(model)
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert isinstance(model.conv2, Conv2dRFSearchOp)
    assert isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
    assert model.conv3.dilation_rates == [(1, 1), (1, 2)]

    # 3. test step() with mode of fixed_single_branch
    runner = _build_demo_runner()
    runner.model = model
    runner.register_hook(rfsearchhook_fixed_multi_branch)
    runner.run([loader], [('train', 1)])
    # Branches stay unchanged since no search step is performed.
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert isinstance(model.conv2, Conv2dRFSearchOp)
    assert isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
    assert model.conv3.dilation_rates == [(1, 1), (1, 2)]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment