Commit 604f8431 authored by Houwen Peng, committed by GitHub

Request for Integrating the new NAS algorithm: Cream (#2705)

parent cda02aff
from collections import OrderedDict

import torch.nn as nn

from lib.utils.util import *
from timm.models.efficientnet_blocks import *
class ChildNetBuilder:
def __init__(
self,
channel_multiplier=1.0,
channel_divisor=8,
channel_min=None,
output_stride=32,
pad_type='',
act_layer=None,
se_kwargs=None,
norm_layer=nn.BatchNorm2d,
norm_kwargs=None,
drop_path_rate=0.,
feature_location='',
verbose=False,
logger=None):
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
self.output_stride = output_stride
self.pad_type = pad_type
self.act_layer = act_layer
self.se_kwargs = se_kwargs
self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs
self.drop_path_rate = drop_path_rate
self.feature_location = feature_location
assert feature_location in ('pre_pwl', 'post_exp', '')
self.verbose = verbose
self.in_chs = None
self.features = OrderedDict()
self.logger = logger
def _round_channels(self, chs):
return round_channels(
chs,
self.channel_multiplier,
self.channel_divisor,
self.channel_min)
def _make_block(self, ba, block_idx, block_count):
drop_path_rate = self.drop_path_rate * block_idx / block_count
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
if 'fake_in_chs' in ba and ba['fake_in_chs']:
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
ba['norm_layer'] = self.norm_layer
ba['norm_kwargs'] = self.norm_kwargs
ba['pad_type'] = self.pad_type
# block act fn overrides the model default
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
assert ba['act_layer'] is not None
if bt == 'ir':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
if self.verbose:
self.logger.info(
' InvertedResidual {}, Args: {}'.format(
block_idx, str(ba)))
block = InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
if self.verbose:
self.logger.info(
' DepthwiseSeparable {}, Args: {}'.format(
block_idx, str(ba)))
block = DepthwiseSeparableConv(**ba)
elif bt == 'cn':
if self.verbose:
self.logger.info(
' ConvBnAct {}, Args: {}'.format(
block_idx, str(ba)))
block = ConvBnAct(**ba)
else:
assert False, 'Unknown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
return block
def __call__(self, in_chs, model_block_args):
""" Build the blocks
Args:
in_chs: Number of input-channels passed to first block
model_block_args: A list of lists, outer list defines stages, inner
list contains strings defining block configuration(s)
Return:
List of block stacks (each stack wrapped in nn.Sequential)
"""
if self.verbose:
self.logger.info(
'Building model trunk with %d stages...' %
len(model_block_args))
self.in_chs = in_chs
total_block_count = sum([len(x) for x in model_block_args])
total_block_idx = 0
current_stride = 2
current_dilation = 1
feature_idx = 0
stages = []
# outer list of block_args defines the stacks ('stages' by some
# conventions)
for stage_idx, stage_block_args in enumerate(model_block_args):
last_stack = stage_idx == (len(model_block_args) - 1)
if self.verbose:
self.logger.info('Stack: {}'.format(stage_idx))
assert isinstance(stage_block_args, list)
blocks = []
# each stack (stage) contains a list of block arguments
for block_idx, block_args in enumerate(stage_block_args):
last_block = block_idx == (len(stage_block_args) - 1)
extract_features = '' # No features extracted
if self.verbose:
self.logger.info(' Block: {}'.format(block_idx))
# Sort out stride, dilation, and feature extraction details
assert block_args['stride'] in (1, 2)
if block_idx >= 1:
# only the first block in any stack can have a stride > 1
block_args['stride'] = 1
do_extract = False
if self.feature_location == 'pre_pwl':
if last_block:
next_stage_idx = stage_idx + 1
if next_stage_idx >= len(model_block_args):
do_extract = True
else:
do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
elif self.feature_location == 'post_exp':
if block_args['stride'] > 1 or (last_stack and last_block):
do_extract = True
if do_extract:
extract_features = self.feature_location
next_dilation = current_dilation
if block_args['stride'] > 1:
next_output_stride = current_stride * block_args['stride']
if next_output_stride > self.output_stride:
next_dilation = current_dilation * block_args['stride']
block_args['stride'] = 1
if self.verbose:
self.logger.info(
' Converting stride to dilation to maintain output_stride=={}'.format(
self.output_stride))
else:
current_stride = next_output_stride
block_args['dilation'] = current_dilation
if next_dilation != current_dilation:
current_dilation = next_dilation
# create the block
block = self._make_block(
block_args, total_block_idx, total_block_count)
blocks.append(block)
# stash feature module name and channel info for model feature
# extraction
if extract_features:
feature_module = block.feature_module(extract_features)
if feature_module:
feature_module = 'blocks.{}.{}.'.format(
stage_idx, block_idx) + feature_module
feature_channels = block.feature_channels(extract_features)
self.features[feature_idx] = dict(
name=feature_module,
num_chs=feature_channels
)
feature_idx += 1
# incr global block idx (across all stacks)
total_block_idx += 1
stages.append(nn.Sequential(*blocks))
return stages
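# --- Illustrative usage sketch (not part of the original code) ---
# ChildNetBuilder is driven with decoded block args, as done in ChildNet below:
#
#   builder = ChildNetBuilder(channel_multiplier=1.0, act_layer=nn.ReLU,
#                             norm_layer=nn.BatchNorm2d, norm_kwargs={})
#   stages = builder(in_chs=16, model_block_args=decode_arch_def(arch_def))
#   trunk = nn.Sequential(*stages)
#
# where `arch_def` is a list of stages, each a list of block strings such as
# 'ir_r1_k3_s2_e4_c24_se0.25' (see decode_block_str further below).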
from copy import deepcopy

import torch.nn as nn

from lib.utils.builder_util import modify_block_args
from lib.models.blocks import get_Bottleneck, InvertedResidual
from timm.models.efficientnet_blocks import *
from nni.nas.pytorch import mutables
class SuperNetBuilder:
""" Build Trunk Blocks
"""
def __init__(
self,
choices,
channel_multiplier=1.0,
channel_divisor=8,
channel_min=None,
output_stride=32,
pad_type='',
act_layer=None,
se_kwargs=None,
norm_layer=nn.BatchNorm2d,
norm_kwargs=None,
drop_path_rate=0.,
feature_location='',
verbose=False,
resunit=False,
dil_conv=False,
logger=None):
# dict
# choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}
self.choices = [[x, y] for x in choices['kernel_size']
for y in choices['exp_ratio']]
self.choices_num = len(self.choices) - 1
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
self.output_stride = output_stride
self.pad_type = pad_type
self.act_layer = act_layer
self.se_kwargs = se_kwargs
self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs
self.drop_path_rate = drop_path_rate
self.feature_location = feature_location
assert feature_location in ('pre_pwl', 'post_exp', '')
self.verbose = verbose
self.resunit = resunit
self.dil_conv = dil_conv
self.logger = logger
# state updated during build, consumed by model
self.in_chs = None
def _round_channels(self, chs):
return round_channels(
chs,
self.channel_multiplier,
self.channel_divisor,
self.channel_min)
def _make_block(
self,
ba,
choice_idx,
block_idx,
block_count,
resunit=False,
dil_conv=False):
drop_path_rate = self.drop_path_rate * block_idx / block_count
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
if 'fake_in_chs' in ba and ba['fake_in_chs']:
# FIXME this is a hack to work around mismatch in origin impl input
# filters
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
ba['norm_layer'] = self.norm_layer
ba['norm_kwargs'] = self.norm_kwargs
ba['pad_type'] = self.pad_type
# block act fn overrides the model default
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
assert ba['act_layer'] is not None
if bt == 'ir':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
if self.verbose:
self.logger.info(
' InvertedResidual {}, Args: {}'.format(
block_idx, str(ba)))
block = InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
if self.verbose:
self.logger.info(
' DepthwiseSeparable {}, Args: {}'.format(
block_idx, str(ba)))
block = DepthwiseSeparableConv(**ba)
elif bt == 'cn':
if self.verbose:
self.logger.info(
' ConvBnAct {}, Args: {}'.format(
block_idx, str(ba)))
block = ConvBnAct(**ba)
else:
assert False, 'Unknown block type (%s) while building model.' % bt
if choice_idx == self.choice_num - 1:
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
return block
def __call__(self, in_chs, model_block_args):
""" Build the blocks
Args:
in_chs: Number of input-channels passed to first block
model_block_args: A list of lists, outer list defines stages, inner
list contains strings defining block configuration(s)
Return:
List of block stacks (each stack wrapped in nn.Sequential)
"""
if self.verbose:
self.logger.info('Building model trunk with %d stages...' % len(model_block_args))
self.in_chs = in_chs
total_block_count = sum([len(x) for x in model_block_args])
total_block_idx = 0
current_stride = 2
current_dilation = 1
feature_idx = 0
stages = []
# outer list of block_args defines the stacks ('stages' by some conventions)
for stage_idx, stage_block_args in enumerate(model_block_args):
last_stack = stage_idx == (len(model_block_args) - 1)
if self.verbose:
self.logger.info('Stack: {}'.format(stage_idx))
assert isinstance(stage_block_args, list)
# blocks = []
# each stack (stage) contains a list of block arguments
for block_idx, block_args in enumerate(stage_block_args):
last_block = block_idx == (len(stage_block_args) - 1)
if self.verbose:
self.logger.info(' Block: {}'.format(block_idx))
# Sort out stride, dilation, and feature extraction details
assert block_args['stride'] in (1, 2)
if block_idx >= 1:
# only the first block in any stack can have a stride > 1
block_args['stride'] = 1
next_dilation = current_dilation
if block_args['stride'] > 1:
next_output_stride = current_stride * block_args['stride']
if next_output_stride > self.output_stride:
next_dilation = current_dilation * block_args['stride']
block_args['stride'] = 1
else:
current_stride = next_output_stride
block_args['dilation'] = current_dilation
if next_dilation != current_dilation:
current_dilation = next_dilation
# the first and the last stage each contain a single fixed block
if stage_idx == 0 or stage_idx == 6:
self.choice_num = 1
else:
self.choice_num = len(self.choices)
if self.dil_conv:
self.choice_num += 2
choice_blocks = []
block_args_copy = deepcopy(block_args)
if self.choice_num == 1:
# create the block
block = self._make_block(block_args, 0, total_block_idx, total_block_count)
choice_blocks.append(block)
else:
for choice_idx, choice in enumerate(self.choices):
# create the block
block_args = deepcopy(block_args_copy)
block_args = modify_block_args(block_args, choice[0], choice[1])
block = self._make_block(block_args, choice_idx, total_block_idx, total_block_count)
choice_blocks.append(block)
if self.dil_conv:
block_args = deepcopy(block_args_copy)
block_args = modify_block_args(block_args, 3, 0)
block = self._make_block(block_args, self.choice_num - 2, total_block_idx, total_block_count,
resunit=self.resunit, dil_conv=self.dil_conv)
choice_blocks.append(block)
block_args = deepcopy(block_args_copy)
block_args = modify_block_args(block_args, 5, 0)
block = self._make_block(block_args, self.choice_num - 1, total_block_idx, total_block_count,
resunit=self.resunit, dil_conv=self.dil_conv)
choice_blocks.append(block)
if self.resunit:
block = get_Bottleneck(block.conv_pw.in_channels,
block.conv_pwl.out_channels,
block.conv_dw.stride[0])
choice_blocks.append(block)
choice_block = mutables.LayerChoice(choice_blocks)
stages.append(choice_block)
# create the block
# block = self._make_block(block_args, total_block_idx, total_block_count)
total_block_idx += 1 # incr global block idx (across all stacks)
# stages.append(blocks)
return stages
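# --- Illustrative usage sketch (not part of the original code) ---
# SuperNetBuilder wraps every searchable block position in an NNI LayerChoice;
# SuperNet below consumes the result as nn.Sequential(*builder(in_chs, block_args)):
#
#   choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}
#   builder = SuperNetBuilder(choices, act_layer=Swish, norm_kwargs={}, logger=logger)
#   layer_choices = builder(in_chs=16, model_block_args=decode_arch_def(arch_def))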
import torch.nn as nn
import torch.nn.functional as F

from lib.utils.builder_util import *
from lib.models.builders.build_childnet import *

from timm.models.layers import SelectAdaptivePool2d
from timm.models.layers.activations import hard_sigmoid
class ChildNet(nn.Module):
def __init__(
self,
block_args,
num_classes=1000,
in_chans=3,
stem_size=16,
num_features=1280,
head_bias=True,
channel_multiplier=1.0,
pad_type='',
act_layer=nn.ReLU,
drop_rate=0.,
drop_path_rate=0.,
se_kwargs=None,
norm_layer=nn.BatchNorm2d,
norm_kwargs=None,
global_pool='avg',
logger=None,
verbose=False):
super(ChildNet, self).__init__()
self.num_classes = num_classes
self.num_features = num_features
self.drop_rate = drop_rate
self._in_chs = in_chans
self.logger = logger
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = create_conv2d(
self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size
# Middle stages (IR/ER/DS Blocks)
builder = ChildNetBuilder(
channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, verbose=verbose)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
# self.blocks = builder(self._in_chs, block_args)
self._in_chs = builder.in_chs
# Head + Pooling
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.conv_head = create_conv2d(
self._in_chs,
self.num_features,
1,
padding=pad_type,
bias=head_bias)
self.act2 = act_layer(inplace=True)
# Classifier
self.classifier = nn.Linear(
self.num_features *
self.global_pool.feat_mult(),
self.num_classes)
efficientnet_init_weights(self)
def get_classifier(self):
return self.classifier
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
self.classifier = nn.Linear(
self.num_features * self.global_pool.feat_mult(),
num_classes) if self.num_classes else None
def forward_features(self, x):
# architecture = [[0], [], [], [], [], [0]]
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
x = self.blocks(x)
x = self.global_pool(x)
x = self.conv_head(x)
x = self.act2(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = x.flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.classifier(x)
return x
def gen_childnet(arch_list, arch_def, **kwargs):
# arch_list = [[0], [], [], [], [], [0]]
choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}
choices_list = [[x, y] for x in choices['kernel_size']
for y in choices['exp_ratio']]
num_features = 1280
# act_layer = HardSwish
act_layer = Swish
new_arch = []
# change to child arch_def
for i, (layer_choice, layer_arch) in enumerate(zip(arch_list, arch_def)):
if len(layer_arch) == 1:
new_arch.append(layer_arch)
continue
else:
new_layer = []
for j, (block_choice, block_arch) in enumerate(
zip(layer_choice, layer_arch)):
kernel_size, exp_ratio = choices_list[block_choice]
elements = block_arch.split('_')
block_arch = block_arch.replace(
elements[2], 'k{}'.format(str(kernel_size)))
block_arch = block_arch.replace(
elements[4], 'e{}'.format(str(exp_ratio)))
new_layer.append(block_arch)
new_arch.append(new_layer)
model_kwargs = dict(
block_args=decode_arch_def(new_arch),
num_features=num_features,
stem_size=16,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=act_layer,
se_kwargs=dict(
act_layer=nn.ReLU,
gate_fn=hard_sigmoid,
reduce_mid=True,
divisor=8),
**kwargs,
)
model = ChildNet(**model_kwargs)
return model
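# --- Example usage (illustrative values; mirrors the retraining entry point below) ---
#
#   arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
#   stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
#   choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25',
#                        'ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
#                        'ir_r1_k3_s2_e6_c192_se0.25']
#   arch_def = [[stem[0]]] + [[choice_block_pool[i]] * len(arch_list[i + 1])
#                             for i in range(len(choice_block_pool))] + [[stem[1]]]
#   model = gen_childnet(arch_list, arch_def, num_classes=1000, drop_rate=0.2)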
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import torch.nn as nn
import torch.nn.functional as F

from lib.utils.builder_util import *
from lib.utils.search_structure_supernet import *
from lib.models.builders.build_supernet import *
from lib.utils.op_by_layer_dict import flops_op_dict

from timm.models.layers import SelectAdaptivePool2d
from timm.models.layers.activations import hard_sigmoid
class SuperNet(nn.Module):
def __init__(
self,
block_args,
choices,
num_classes=1000,
in_chans=3,
stem_size=16,
num_features=1280,
head_bias=True,
channel_multiplier=1.0,
pad_type='',
act_layer=nn.ReLU,
drop_rate=0.,
drop_path_rate=0.,
slice=4,
se_kwargs=None,
norm_layer=nn.BatchNorm2d,
logger=None,
norm_kwargs=None,
global_pool='avg',
resunit=False,
dil_conv=False,
verbose=False):
super(SuperNet, self).__init__()
self.num_classes = num_classes
self.num_features = num_features
self.drop_rate = drop_rate
self._in_chs = in_chans
self.logger = logger
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = create_conv2d(
self._in_chs, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self._in_chs = stem_size
# Middle stages (IR/ER/DS Blocks)
builder = SuperNetBuilder(
choices,
channel_multiplier,
8,
None,
32,
pad_type,
act_layer,
se_kwargs,
norm_layer,
norm_kwargs,
drop_path_rate,
verbose=verbose,
resunit=resunit,
dil_conv=dil_conv,
logger=self.logger)
blocks = builder(self._in_chs, block_args)
self.blocks = nn.Sequential(*blocks)
self._in_chs = builder.in_chs
# Head + Pooling
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.conv_head = create_conv2d(
self._in_chs,
self.num_features,
1,
padding=pad_type,
bias=head_bias)
self.act2 = act_layer(inplace=True)
# Classifier
self.classifier = nn.Linear(
self.num_features *
self.global_pool.feat_mult(),
self.num_classes)
self.meta_layer = nn.Linear(self.num_classes * slice, 1)
efficientnet_init_weights(self)
def get_classifier(self):
return self.classifier
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_classes = num_classes
self.classifier = nn.Linear(
self.num_features * self.global_pool.feat_mult(),
num_classes) if self.num_classes else None
def forward_features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
x = self.blocks(x)
x = self.global_pool(x)
x = self.conv_head(x)
x = self.act2(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = x.flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
def forward_meta(self, features):
return self.meta_layer(features.view(1, -1))
def rand_parameters(self, architecture, meta=False):
for name, param in self.named_parameters(recurse=True):
if 'meta' in name and meta:
yield param
elif 'blocks' not in name and 'meta' not in name and (not meta):
yield param
if not meta:
for layer, layer_arch in zip(self.blocks, architecture):
for blocks, arch in zip(layer, layer_arch):
if arch == -1:
continue
for name, param in blocks[arch].named_parameters(
recurse=True):
yield param
class Classifier(nn.Module):
def __init__(self, num_classes=1000):
super(Classifier, self).__init__()
self.classifier = nn.Linear(num_classes, num_classes)
def forward(self, x):
return self.classifier(x)
def gen_supernet(flops_minimum=0, flops_maximum=600, **kwargs):
choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}
num_features = 1280
# act_layer = HardSwish
act_layer = Swish
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_se0.25'],
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25',
'ir_r1_k3_s1_e4_c24_se0.25'],
# stage 2, 56x56 in
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s1_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25',
'ir_r1_k5_s2_e4_c40_se0.25'],
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25',
'ir_r2_k3_s1_e4_c80_se0.25'],
# stage 4, 14x14in
['ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
'ir_r1_k3_s1_e6_c96_se0.25'],
# stage 5, 14x14in
['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25',
'ir_r1_k5_s2_e6_c192_se0.25'],
# stage 6, 7x7 in
['cn_r1_k1_s1_c320_se0.25'],
]
sta_num, arch_def, resolution = search_for_layer(
flops_op_dict, arch_def, flops_minimum, flops_maximum)
if sta_num is None or arch_def is None or resolution is None:
raise ValueError('Invalid FLOPs Settings')
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
choices=choices,
num_features=num_features,
stem_size=16,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=act_layer,
se_kwargs=dict(
act_layer=nn.ReLU,
gate_fn=hard_sigmoid,
reduce_mid=True,
divisor=8),
**kwargs,
)
model = SuperNet(**model_kwargs)
return model, sta_num, resolution
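# --- Example usage (illustrative; argument values are assumptions) ---
#
#   model, sta_num, resolution = gen_supernet(
#       flops_minimum=0, flops_maximum=600,
#       num_classes=1000, drop_rate=0.0, global_pool='avg')
#
# `sta_num` is the searched number of layers per stage and `resolution`
# the suggested input resolution (a multiple of 32).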
import re
import math
from copy import deepcopy

import torch.nn as nn

from timm.utils import *
from timm.models.layers.activations import Swish
from timm.models.layers import CondConv2d, get_condconv_initializer
def parse_ksize(ss):
if ss.isdigit():
return int(ss)
else:
return [int(k) for k in ss.split('.')]
def decode_arch_def(
arch_def,
depth_multiplier=1.0,
depth_trunc='ceil',
experts_multiplier=1):
arch_args = []
for stack_idx, block_strings in enumerate(arch_def):
assert isinstance(block_strings, list)
stack_args = []
repeats = []
for block_str in block_strings:
assert isinstance(block_str, str)
ba, rep = decode_block_str(block_str)
if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
ba['num_experts'] *= experts_multiplier
stack_args.append(ba)
repeats.append(rep)
arch_args.append(
scale_stage_depth(
stack_args,
repeats,
depth_multiplier,
depth_trunc))
return arch_args
def modify_block_args(block_args, kernel_size, exp_ratio):
block_type = block_args['block_type']
if block_type == 'cn':
block_args['kernel_size'] = kernel_size
elif block_type == 'er':
block_args['exp_kernel_size'] = kernel_size
else:
block_args['dw_kernel_size'] = kernel_size
if block_type == 'ir' or block_type == 'er':
block_args['exp_ratio'] = exp_ratio
return block_args
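# Worked example (follows directly from the branches above, shown for clarity):
# for an inverted-residual ('ir') block and the choice (kernel_size=5, exp_ratio=6),
# modify_block_args sets block_args['dw_kernel_size'] = 5 and
# block_args['exp_ratio'] = 6; a 'cn' block only gets 'kernel_size' replaced
# and ignores exp_ratio; an 'er' block gets 'exp_kernel_size' instead.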
def decode_block_str(block_str):
""" Decode block definition string
Gets a list of block arg (dicts) through a string notation of arguments.
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
All args can exist in any order with the exception of the leading string which
is assumed to indicate the block type.
leading string - block type (
ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
r - number of repeat blocks,
k - kernel size,
s - strides (1-9),
e - expansion ratio,
c - output channels,
se - squeeze/excitation ratio
n - activation fn ('re', 'r6', 'hs', or 'sw')
Args:
block_str: a string representation of block arguments.
Returns:
A list of block args (dicts)
Raises:
ValueError: if the string definition is not properly specified (TODO)
"""
assert isinstance(block_str, str)
ops = block_str.split('_')
block_type = ops[0] # take the block type off the front
ops = ops[1:]
options = {}
noskip = False
for op in ops:
# string options being checked on individual basis, combine if they
# grow
if op == 'noskip':
noskip = True
elif op.startswith('n'):
# activation fn
key = op[0]
v = op[1:]
if v == 're':
value = nn.ReLU
elif v == 'r6':
value = nn.ReLU6
elif v == 'sw':
value = Swish
else:
continue
options[key] = value
else:
# all numeric options
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
# if act_layer is None, the model default (passed to model init) will be
# used
act_layer = options['n'] if 'n' in options else None
exp_kernel_size = parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = parse_ksize(options['p']) if 'p' in options else 1
# FIXME hack to deal with in_chs issue in TPU def
fake_in_chs = int(options['fc']) if 'fc' in options else 0
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
if block_type == 'ir':
block_args = dict(
block_type=block_type,
dw_kernel_size=parse_ksize(options['k']),
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
)
if 'cc' in options:
block_args['num_experts'] = int(options['cc'])
elif block_type == 'ds' or block_type == 'dsa':
block_args = dict(
block_type=block_type,
dw_kernel_size=parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_layer=act_layer,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
)
elif block_type == 'cn':
block_args = dict(
block_type=block_type,
kernel_size=int(options['k']),
out_chs=int(options['c']),
stride=int(options['s']),
act_layer=act_layer,
)
else:
assert False, 'Unknown block type (%s)' % block_type
return block_args, num_repeat
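# Worked example (result follows from the parsing logic above):
#
#   decode_block_str('ir_r1_k3_s2_e4_c24_se0.25')
#   # -> ({'block_type': 'ir', 'dw_kernel_size': 3, 'exp_kernel_size': 1,
#   #      'pw_kernel_size': 1, 'out_chs': 24, 'exp_ratio': 4.0,
#   #      'se_ratio': 0.25, 'stride': 2, 'act_layer': None, 'noskip': False}, 1)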
def scale_stage_depth(
stack_args,
repeats,
depth_multiplier=1.0,
depth_trunc='ceil'):
""" Per-stage depth scaling
Scales the block repeats in each stage. This depth scaling impl maintains
compatibility with the EfficientNet scaling method, while allowing sensible
scaling for other models that may have multiple block arg definitions in each stage.
"""
# We scale the total repeat count for each stage, there may be multiple
# block arg defs per stage so we need to sum.
num_repeat = sum(repeats)
if depth_trunc == 'round':
# Truncating to int by rounding allows stages with few repeats to remain
# proportionally smaller for longer. This is a good choice when stage definitions
# include single repeat stages that we'd prefer to keep that way as
# long as possible
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
else:
# The default for EfficientNet truncates repeats to int via 'ceil'.
# Any multiplier > 1.0 will result in an increased depth for every
# stage.
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
# Proportionally distribute repeat count scaling to each block definition in the stage.
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
# The first block makes less sense to repeat in most of the arch
# definitions.
repeats_scaled = []
for r in repeats[::-1]:
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
repeats_scaled.append(rs)
num_repeat -= r
num_repeat_scaled -= rs
repeats_scaled = repeats_scaled[::-1]
# Apply the calculated scaling to each block arg in the stage
sa_scaled = []
for ba, rep in zip(stack_args, repeats_scaled):
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
return sa_scaled
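# Worked example (arithmetic follows the allocation loop above):
# with repeats=[1, 2, 3] and depth_multiplier=1.2 ('ceil' truncation):
#   num_repeat = 6, num_repeat_scaled = ceil(6 * 1.2) = 8
#   reverse allocation gives 3 -> 4, 2 -> 3, 1 -> 1, i.e. repeats_scaled = [1, 3, 4]
# so the three block defs are deep-copied 1, 3 and 4 times respectively.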
def init_weight_goog(m, n='', fix_group_fanout=True, last_bn=None):
""" Weight initialization as per Tensorflow official implementations.
Args:
m (nn.Module): module to init
n (str): module name
fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
* https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
"""
if isinstance(m, CondConv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
init_weight_fn = get_condconv_initializer(lambda w: w.data.normal_(
0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
init_weight_fn(m.weight)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
if last_bn and n in last_bn:
# zero-gamma trick: zero-init the last BN in each block
m.weight.data.zero_()
m.bias.data.zero_()
else:
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
fan_out = m.weight.size(0) # fan-out
fan_in = 0
if 'routing_fn' in n:
fan_in = m.weight.size(1)
init_range = 1.0 / math.sqrt(fan_in + fan_out)
m.weight.data.uniform_(-init_range, init_range)
m.bias.data.zero_()
def efficientnet_init_weights(
model: nn.Module,
init_fn=None,
zero_gamma=False):
last_bn = []
if zero_gamma:
prev_n = ''
for n, m in model.named_modules():
if isinstance(m, nn.BatchNorm2d):
if ''.join(prev_n.split('.')[:-1]) != ''.join(n.split('.')[:-1]):
last_bn.append(prev_n)
prev_n = n
last_bn.append(prev_n)
init_fn = init_fn or init_weight_goog
for n, m in model.named_modules():
init_fn(m, n, last_bn=last_bn)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import torch
from ptflops import get_model_complexity_info
class FlopsEst(object):
def __init__(self, model, input_shape=(2, 3, 224, 224), device='cpu'):
self.block_num = len(model.blocks)
self.choice_num = len(model.blocks[0])
self.flops_dict = {}
self.params_dict = {}
if device == 'cpu':
model = model.cpu()
else:
model = model.cuda()
self.params_fixed = 0
self.flops_fixed = 0
input = torch.randn(input_shape)
flops, params = get_model_complexity_info(
model.conv_stem, (3, 224, 224), as_strings=False, print_per_layer_stat=False)
self.params_fixed += params / 1e6
self.flops_fixed += flops / 1e6
input = model.conv_stem(input)
for block_id, block in enumerate(model.blocks):
self.flops_dict[block_id] = {}
self.params_dict[block_id] = {}
for module_id, module in enumerate(block):
flops, params = get_model_complexity_info(module, tuple(
input.shape[1:]), as_strings=False, print_per_layer_stat=False)
# Flops(M)
self.flops_dict[block_id][module_id] = flops / 1e6
# Params(M)
self.params_dict[block_id][module_id] = params / 1e6
input = module(input)
# global pooling
flops, params = get_model_complexity_info(model.global_pool, tuple(
input.shape[1:]), as_strings=False, print_per_layer_stat=False)
self.params_fixed += params / 1e6
self.flops_fixed += flops / 1e6
input = model.global_pool(input)
# conv_head
flops, params = get_model_complexity_info(model.conv_head, tuple(
input.shape[1:]), as_strings=False, print_per_layer_stat=False)
self.params_fixed += params / 1e6
self.flops_fixed += flops / 1e6
# return params (M)
def get_params(self, arch):
params = 0
for block_id, block in enumerate(arch):
if block == -1:
continue
params += self.params_dict[block_id][block]
return params + self.params_fixed
# return flops (M)
def get_flops(self, arch):
flops = 0
for block_id, block in enumerate(arch):
if block == 'LayerChoice1' or block == 'LayerChoice23':
continue
for idx, choice in enumerate(arch[block]):
flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
return flops + self.flops_fixed
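# --- Usage sketch (illustrative; the architecture encoding is an assumption
# based on how FlopsEst is consumed in the supernet training entry point) ---
#
#   est = FlopsEst(supernet_model)
#   flops_dict, flops_fixed = est.flops_dict, est.flops_fixed  # MFLOPs lookup tables
#   # one chosen op index per searchable block, -1 meaning the block is skipped
#   params_m = est.get_params([0, 2, -1, 1])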
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
# This dictionary is generated from calculating each operation of each layer to quickly search for layers.
# flops_op_dict[which_stage][which_operation] =
# (flops_of_operation_with_stride1, flops_of_operation_with_stride2)
flops_op_dict = {}
for i in range(5):
flops_op_dict[i] = {}
flops_op_dict[0][0] = (21.828704, 18.820752)
flops_op_dict[0][1] = (32.669328, 28.16048)
flops_op_dict[0][2] = (25.039968, 23.637648)
flops_op_dict[0][3] = (37.486224, 35.385824)
flops_op_dict[0][4] = (29.856864, 30.862992)
flops_op_dict[0][5] = (44.711568, 46.22384)
flops_op_dict[1][0] = (11.808656, 11.86712)
flops_op_dict[1][1] = (17.68624, 17.780848)
flops_op_dict[1][2] = (13.01288, 13.87416)
flops_op_dict[1][3] = (19.492576, 20.791408)
flops_op_dict[1][4] = (14.819216, 16.88472)
flops_op_dict[1][5] = (22.20208, 25.307248)
flops_op_dict[2][0] = (8.198, 10.99632)
flops_op_dict[2][1] = (12.292848, 16.5172)
flops_op_dict[2][2] = (8.69976, 11.99984)
flops_op_dict[2][3] = (13.045488, 18.02248)
flops_op_dict[2][4] = (9.4524, 13.50512)
flops_op_dict[2][5] = (14.174448, 20.2804)
flops_op_dict[3][0] = (12.006112, 15.61632)
flops_op_dict[3][1] = (18.028752, 23.46096)
flops_op_dict[3][2] = (13.009632, 16.820544)
flops_op_dict[3][3] = (19.534032, 25.267296)
flops_op_dict[3][4] = (14.514912, 18.62688)
flops_op_dict[3][5] = (21.791952, 27.9768)
flops_op_dict[4][0] = (11.307456, 15.292416)
flops_op_dict[4][1] = (17.007072, 23.1504)
flops_op_dict[4][2] = (11.608512, 15.894528)
flops_op_dict[4][3] = (17.458656, 24.053568)
flops_op_dict[4][4] = (12.060096, 16.797696)
flops_op_dict[4][5] = (18.136032, 25.40832)
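# Access example: flops_op_dict[stage][op] is a (stride-1, stride-2) pair in MFLOPs,
# e.g. flops_op_dict[2][3][1] == 18.02248 is the cost of operation 3 in stage 2
# when it runs with stride 2.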
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
def search_for_layer(flops_op_dict, arch_def, flops_minimum, flops_maximum):
sta_num = [1, 1, 1, 1, 1]
order = [2, 3, 4, 1, 0, 2, 3, 4, 1, 0]
limits = [3, 3, 3, 2, 2, 4, 4, 4, 4, 4]
size_factor = 224 // 32
base_min_flops = sum([flops_op_dict[i][0][0] for i in range(5)])
base_max_flops = sum([flops_op_dict[i][5][0] for i in range(5)])
if base_min_flops > flops_maximum:
while base_min_flops > flops_maximum and size_factor >= 2:
size_factor = size_factor - 1
flops_minimum = flops_minimum * (7. / size_factor)
flops_maximum = flops_maximum * (7. / size_factor)
if size_factor < 2:
return None, None, None
elif base_max_flops < flops_minimum:
cur_ptr = 0
while base_max_flops < flops_minimum and cur_ptr <= 9:
if sta_num[order[cur_ptr]] >= limits[cur_ptr]:
cur_ptr += 1
continue
base_max_flops = base_max_flops + \
flops_op_dict[order[cur_ptr]][5][1]
sta_num[order[cur_ptr]] += 1
if cur_ptr > 7 and base_max_flops < flops_minimum:
return None, None, None
cur_ptr = 0
while cur_ptr <= 9:
if sta_num[order[cur_ptr]] >= limits[cur_ptr]:
cur_ptr += 1
continue
base_max_flops = base_max_flops + flops_op_dict[order[cur_ptr]][5][1]
if base_max_flops <= flops_maximum:
sta_num[order[cur_ptr]] += 1
else:
break
arch_def = [item[:i] for i, item in zip([1] + sta_num + [1], arch_def)]
# print(arch_def)
return sta_num, arch_def, size_factor * 32
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import sys
import logging
import argparse
from copy import deepcopy

import torch
import torch.nn as nn
from torch import optim as optim
from thop import profile, clever_format
from timm.utils import *

from lib.config import cfg
def get_path_acc(model, path, val_loader, args, val_iters=50):
prec1_m = AverageMeter()
prec5_m = AverageMeter()
with torch.no_grad():
for batch_idx, (input, target) in enumerate(val_loader):
if batch_idx >= val_iters:
break
if not args.prefetcher:
input = input.cuda()
target = target.cuda()
output = model(input, path)
if isinstance(output, (tuple, list)):
output = output[0]
# augmentation reduction
reduce_factor = args.tta
if reduce_factor > 1:
output = output.unfold(
0,
reduce_factor,
reduce_factor).mean(
dim=2)
target = target[0:target.size(0):reduce_factor]
prec1, prec5 = accuracy(output, target, topk=(1, 5))
torch.cuda.synchronize()
prec1_m.update(prec1.item(), output.size(0))
prec5_m.update(prec5.item(), output.size(0))
return (prec1_m.avg, prec5_m.avg)
def get_logger(file_path):
""" Make python logger """
log_format = '%(asctime)s | %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
format=log_format, datefmt='%m/%d %I:%M:%S %p')
logger = logging.getLogger('')
formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')
file_handler = logging.FileHandler(file_path)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
return logger
def add_weight_decay_supernet(model, args, weight_decay=1e-5, skip_list=()):
decay = []
no_decay = []
meta_layer_no_decay = []
meta_layer_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith(
".bias") or name in skip_list:
if 'meta_layer' in name:
meta_layer_no_decay.append(param)
else:
no_decay.append(param)
else:
if 'meta_layer' in name:
meta_layer_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0., 'lr': args.lr},
{'params': decay, 'weight_decay': weight_decay, 'lr': args.lr},
{'params': meta_layer_no_decay, 'weight_decay': 0., 'lr': args.meta_lr},
{'params': meta_layer_decay, 'weight_decay': 0, 'lr': args.meta_lr},
]
def create_optimizer_supernet(args, model, has_apex, filter_bias_and_bn=True):
opt_lower = args.opt.lower()
weight_decay = args.weight_decay
if 'adamw' in opt_lower or 'radam' in opt_lower:
weight_decay /= args.lr
if weight_decay and filter_bias_and_bn:
parameters = add_weight_decay_supernet(model, args, weight_decay)
weight_decay = 0.
else:
parameters = model.parameters()
if 'fused' in opt_lower:
assert has_apex and torch.cuda.is_available(
), 'APEX and CUDA required for fused optimizers'
opt_split = opt_lower.split('_')
opt_lower = opt_split[-1]
if opt_lower == 'sgd' or opt_lower == 'nesterov':
optimizer = optim.SGD(
parameters,
momentum=args.momentum,
weight_decay=weight_decay,
nesterov=True)
elif opt_lower == 'momentum':
optimizer = optim.SGD(
parameters,
momentum=args.momentum,
weight_decay=weight_decay,
nesterov=False)
elif opt_lower == 'adam':
optimizer = optim.Adam(
parameters, weight_decay=weight_decay, eps=args.opt_eps)
else:
raise ValueError('Invalid optimizer: {}'.format(opt_lower))
return optimizer
def convert_lowercase(cfg):
keys = cfg.keys()
lowercase_keys = [key.lower() for key in keys]
values = [cfg.get(key) for key in keys]
for lowercase_key, value in zip(lowercase_keys, values):
cfg.setdefault(lowercase_key, value)
return cfg
def parse_config_args(exp_name):
parser = argparse.ArgumentParser(description=exp_name)
parser.add_argument(
'--cfg',
type=str,
default='../experiments/workspace/retrain/retrain.yaml',
help='configuration of cream')
parser.add_argument('--local_rank', type=int, default=0,
help='local_rank')
args = parser.parse_args()
cfg.merge_from_file(args.cfg)
converted_cfg = convert_lowercase(cfg)
return args, converted_cfg
def get_model_flops_params(model, input_size=(1, 3, 224, 224)):
input = torch.randn(input_size)
macs, params = profile(deepcopy(model), inputs=(input,), verbose=False)
macs, params = clever_format([macs, params], "%.3f")
return macs, params
def cross_entropy_loss_with_soft_target(pred, soft_target):
logsoftmax = nn.LogSoftmax(dim=1)
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
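# Usage sketch (illustrative): this loss is intended for soft (e.g. distillation)
# targets, matching a student's prediction to a teacher's softened output:
#
#   with torch.no_grad():
#       soft_target = torch.softmax(teacher_logits, dim=1)
#   loss = cross_entropy_loss_with_soft_target(student_logits, soft_target)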
def create_supernet_scheduler(cfg, optimizer):
ITERS = cfg.EPOCHS * \
(1280000 / (cfg.NUM_GPU * cfg.DATASET.BATCH_SIZE))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: (
cfg.LR - step / ITERS) if step <= ITERS else 0, last_epoch=-1)
return lr_scheduler, cfg.EPOCHS
yacs
numpy==1.17
opencv-python==4.0.1.24
torchvision==0.2.1
thop
git+https://github.com/sovrasov/flops-counter.pytorch.git
pillow==6.1.0
torch==1.2
timm==0.1.20
tensorboardx==1.2
tensorboard
future
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import os
import warnings
import datetime
import torch
import numpy as np
import torch.nn as nn
from torchscope import scope
from torch.utils.tensorboard import SummaryWriter
# import timm packages
from timm.optim import create_optimizer
from timm.models import resume_checkpoint
from timm.scheduler import create_scheduler
from timm.data import Dataset, create_loader
from timm.utils import ModelEma, update_summary
from timm.loss import LabelSmoothingCrossEntropy
# import apex as distributed package
try:
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
from apex.parallel import convert_syncbn_model
HAS_APEX = True
except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP
HAS_APEX = False
# import models and training functions
from lib.core.test import validate
from lib.core.retrain import train_epoch
from lib.models.structures.childnet import gen_childnet
from lib.utils.util import parse_config_args, get_logger, get_model_flops_params
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def main():
args, cfg = parse_config_args('nni.cream.childnet')
# resolve logging
output_dir = os.path.join(cfg.SAVE_PATH,
"{}-{}".format(datetime.date.today().strftime('%m%d'),
cfg.MODEL))
if not os.path.exists(output_dir):
os.mkdir(output_dir)
if args.local_rank == 0:
logger = get_logger(os.path.join(output_dir, 'retrain.log'))
writer = SummaryWriter(os.path.join(output_dir, 'runs'))
else:
writer, logger = None, None
# retrain model selection
if cfg.NET.SELECTION == 481:
arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1],
[3, 3, 3, 3], [3, 3, 3, 3], [0]]
cfg.DATASET.IMAGE_SIZE = 224
elif cfg.NET.SELECTION == 43:
arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
cfg.DATASET.IMAGE_SIZE = 96
elif cfg.NET.SELECTION == 14:
arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
cfg.DATASET.IMAGE_SIZE = 64
elif cfg.NET.SELECTION == 112:
arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
cfg.DATASET.IMAGE_SIZE = 160
elif cfg.NET.SELECTION == 287:
arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
cfg.DATASET.IMAGE_SIZE = 224
elif cfg.NET.SELECTION == 604:
arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
[3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
cfg.DATASET.IMAGE_SIZE = 224
elif cfg.NET.SELECTION == -1:
arch_list = cfg.NET.INPUT_ARCH
cfg.DATASET.IMAGE_SIZE = 224
else:
raise ValueError("Model Retrain Selection is not Supported!")
# define childnet architecture from arch_list
stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25',
'ir_r1_k5_s2_e4_c40_se0.25',
'ir_r1_k3_s2_e6_c80_se0.25',
'ir_r1_k3_s1_e6_c96_se0.25',
'ir_r1_k3_s2_e6_c192_se0.25']
arch_def = [[stem[0]]] + [[choice_block_pool[idx]
for repeat_times in range(len(arch_list[idx + 1]))]
for idx in range(len(choice_block_pool))] + [[stem[1]]]
# generate childnet
model = gen_childnet(
arch_list,
arch_def,
num_classes=cfg.DATASET.NUM_CLASSES,
drop_rate=cfg.NET.DROPOUT_RATE,
global_pool=cfg.NET.GP)
# initialize training parameters
eval_metric = cfg.EVAL_METRICS
best_metric, best_epoch, saver = None, None, None
# initialize distributed parameters
distributed = cfg.NUM_GPU > 1
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
if args.local_rank == 0:
logger.info(
'Training on Process {} with {} GPUs.'.format(
args.local_rank, cfg.NUM_GPU))
# fix random seeds
torch.manual_seed(cfg.SEED)
torch.cuda.manual_seed_all(cfg.SEED)
np.random.seed(cfg.SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# get parameters and FLOPs of model
if args.local_rank == 0:
macs, params = get_model_flops_params(model, input_size=(
1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE))
logger.info(
'[Model-{}] Flops: {} Params: {}'.format(cfg.NET.SELECTION, macs, params))
# create optimizer
optimizer = create_optimizer(cfg, model)
model = model.cuda()
# optionally resume from a checkpoint
resume_state, resume_epoch = {}, None
if cfg.AUTO_RESUME:
resume_state, resume_epoch = resume_checkpoint(model, cfg.RESUME_PATH)
optimizer.load_state_dict(resume_state['optimizer'])
del resume_state
model_ema = None
if cfg.NET.EMA.USE:
model_ema = ModelEma(
model,
decay=cfg.NET.EMA.DECAY,
device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
resume=cfg.RESUME_PATH if cfg.AUTO_RESUME else None)
if distributed:
if cfg.BATCHNORM.SYNC_BN:
try:
if HAS_APEX:
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
model)
if args.local_rank == 0:
logger.info(
'Converted model to use Synchronized BatchNorm.')
except Exception as e:
if args.local_rank == 0:
logger.error(
'Failed to enable Synchronized BatchNorm: {}. Install Apex or use Torch >= 1.1.'.format(e))
if HAS_APEX:
model = DDP(model, delay_allreduce=True)
else:
if args.local_rank == 0:
logger.info(
"Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
# can use device str in Torch >= 1.1
model = DDP(model, device_ids=[args.local_rank])
# imagenet train dataset
train_dir = os.path.join(cfg.DATA_DIR, 'train')
if not os.path.exists(train_dir) and args.local_rank == 0:
logger.error('Training folder does not exist at: {}'.format(train_dir))
exit(1)
dataset_train = Dataset(train_dir)
loader_train = create_loader(
dataset_train,
input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
batch_size=cfg.DATASET.BATCH_SIZE,
is_training=True,
color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
auto_augment=cfg.AUGMENTATION.AA,
num_aug_splits=0,
crop_pct=DEFAULT_CROP_PCT,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=cfg.WORKERS,
distributed=distributed,
collate_fn=None,
pin_memory=cfg.DATASET.PIN_MEM,
interpolation='random',
re_mode=cfg.AUGMENTATION.RE_MODE,
re_prob=cfg.AUGMENTATION.RE_PROB
)
# imagenet validation dataset
eval_dir = os.path.join(cfg.DATA_DIR, 'val')
if not os.path.exists(eval_dir) and args.local_rank == 0:
logger.error(
'Validation folder does not exist at: {}'.format(eval_dir))
exit(1)
dataset_eval = Dataset(eval_dir)
loader_eval = create_loader(
dataset_eval,
input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
is_training=False,
interpolation=cfg.DATASET.INTERPOLATION,
crop_pct=DEFAULT_CROP_PCT,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=cfg.WORKERS,
distributed=distributed,
pin_memory=cfg.DATASET.PIN_MEM
)
# whether to use label smoothing
if cfg.AUGMENTATION.SMOOTHING > 0.:
train_loss_fn = LabelSmoothingCrossEntropy(
smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
else:
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = train_loss_fn
# create learning rate scheduler
lr_scheduler, num_epochs = create_scheduler(cfg, optimizer)
start_epoch = resume_epoch if resume_epoch is not None else 0
if start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
logger.info('Scheduled epochs: {}'.format(num_epochs))
try:
best_record, best_ep = 0, 0
for epoch in range(start_epoch, num_epochs):
if distributed:
loader_train.sampler.set_epoch(epoch)
train_metrics = train_epoch(
epoch,
model,
loader_train,
optimizer,
train_loss_fn,
cfg,
lr_scheduler=lr_scheduler,
saver=saver,
output_dir=output_dir,
model_ema=model_ema,
logger=logger,
writer=writer,
local_rank=args.local_rank)
eval_metrics = validate(
epoch,
model,
loader_eval,
validate_loss_fn,
cfg,
logger=logger,
writer=writer,
local_rank=args.local_rank)
if model_ema is not None and not cfg.NET.EMA.FORCE_CPU:
ema_eval_metrics = validate(
epoch,
model_ema.ema,
loader_eval,
validate_loss_fn,
cfg,
log_suffix='_EMA',
logger=logger,
writer=writer)
eval_metrics = ema_eval_metrics
if lr_scheduler is not None:
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
update_summary(epoch, train_metrics, eval_metrics, os.path.join(
output_dir, 'summary.csv'), write_header=best_metric is None)
if saver is not None:
# save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(
model, optimizer, cfg,
epoch=epoch, model_ema=model_ema, metric=save_metric)
if best_record < eval_metrics[eval_metric]:
best_record = eval_metrics[eval_metric]
best_ep = epoch
if args.local_rank == 0:
logger.info(
'*** Best metric: {0} (epoch {1})'.format(best_record, best_ep))
except KeyboardInterrupt:
pass
if best_metric is not None:
logger.info(
'*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
if __name__ == '__main__':
main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import os
import warnings
import datetime
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
# import timm packages
from timm.utils import ModelEma
from timm.models import resume_checkpoint
from timm.data import Dataset, create_loader
# import apex as distributed package
try:
from apex.parallel import convert_syncbn_model
from apex.parallel import DistributedDataParallel as DDP
HAS_APEX = True
except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP
HAS_APEX = False
# import models and training functions
from lib.core.test import validate
from lib.models.structures.childnet import gen_childnet
from lib.utils.util import parse_config_args, get_logger, get_model_flops_params
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def main():
args, cfg = parse_config_args('child net testing')
# resolve logging
output_dir = os.path.join(cfg.SAVE_PATH,
"{}-{}".format(datetime.date.today().strftime('%m%d'),
cfg.MODEL))
if not os.path.exists(output_dir):
os.mkdir(output_dir)
if args.local_rank == 0:
logger = get_logger(os.path.join(output_dir, 'test.log'))
writer = SummaryWriter(os.path.join(output_dir, 'runs'))
else:
writer, logger = None, None
# retrain model selection
if cfg.NET.SELECTION == 481:
arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1],
[3, 3, 3, 3], [3, 3, 3, 3], [0]]
cfg.DATASET.IMAGE_SIZE = 224
elif cfg.NET.SELECTION == 43:
arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
cfg.DATASET.IMAGE_SIZE = 96
elif cfg.NET.SELECTION == 14:
arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
cfg.DATASET.IMAGE_SIZE = 64
elif cfg.NET.SELECTION == 112:
arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
cfg.DATASET.IMAGE_SIZE = 160
elif cfg.NET.SELECTION == 287:
arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
cfg.DATASET.IMAGE_SIZE = 224
elif cfg.NET.SELECTION == 604:
arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
[3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
cfg.DATASET.IMAGE_SIZE = 224
else:
raise ValueError("Model Test Selection is not Supported!")
# define childnet architecture from arch_list
stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25',
'ir_r1_k5_s2_e4_c40_se0.25',
'ir_r1_k3_s2_e6_c80_se0.25',
'ir_r1_k3_s1_e6_c96_se0.25',
'ir_r1_k3_s2_e6_c192_se0.25']
arch_def = [[stem[0]]] + [[choice_block_pool[idx]
for repeat_times in range(len(arch_list[idx + 1]))]
for idx in range(len(choice_block_pool))] + [[stem[1]]]
# generate childnet
model = gen_childnet(
arch_list,
arch_def,
num_classes=cfg.DATASET.NUM_CLASSES,
drop_rate=cfg.NET.DROPOUT_RATE,
global_pool=cfg.NET.GP)
if args.local_rank == 0:
macs, params = get_model_flops_params(model, input_size=(
1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE))
logger.info(
'[Model-{}] Flops: {} Params: {}'.format(cfg.NET.SELECTION, macs, params))
# initialize distributed parameters
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
if args.local_rank == 0:
logger.info(
"Training on Process {} with {} GPUs.".format(
args.local_rank, cfg.NUM_GPU))
# resume model from checkpoint
assert cfg.AUTO_RESUME is True and os.path.exists(cfg.RESUME_PATH)
_, __ = resume_checkpoint(model, cfg.RESUME_PATH)
model = model.cuda()
model_ema = None
if cfg.NET.EMA.USE:
# Important to create EMA model after cuda(), DP wrapper, and AMP but
# before SyncBN and DDP wrapper
model_ema = ModelEma(
model,
decay=cfg.NET.EMA.DECAY,
device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
resume=cfg.RESUME_PATH)
# imagenet validation dataset
eval_dir = os.path.join(cfg.DATA_DIR, 'val')
if not os.path.exists(eval_dir) and args.local_rank == 0:
logger.error(
'Validation folder does not exist at: {}'.format(eval_dir))
exit(1)
dataset_eval = Dataset(eval_dir)
loader_eval = create_loader(
dataset_eval,
input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
is_training=False,
num_workers=cfg.WORKERS,
distributed=True,
pin_memory=cfg.DATASET.PIN_MEM,
crop_pct=DEFAULT_CROP_PCT,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
)
# only test accuracy of model-EMA
validate_loss_fn = nn.CrossEntropyLoss().cuda()
validate(0, model_ema.ema, loader_eval, validate_loss_fn, cfg,
log_suffix='_EMA', logger=logger,
writer=writer, local_rank=args.local_rank)
if __name__ == '__main__':
main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import os
import sys
import datetime
import torch
import numpy as np
import torch.nn as nn
# import timm packages
from timm.loss import LabelSmoothingCrossEntropy
from timm.data import Dataset, create_loader
from timm.models import resume_checkpoint
# import apex as distributed package
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.parallel import convert_syncbn_model
USE_APEX = True
except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP
USE_APEX = False
# import models and training functions
from lib.utils.flops_table import FlopsEst
from lib.models.structures.supernet import gen_supernet
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from lib.utils.util import parse_config_args, get_logger, \
create_optimizer_supernet, create_supernet_scheduler
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.nas.pytorch.callbacks import ModelCheckpoint
from nni.algorithms.nas.pytorch.cream import CreamSupernetTrainer
from nni.algorithms.nas.pytorch.random import RandomMutator
def main():
args, cfg = parse_config_args('nni.cream.supernet')
# resolve logging
output_dir = os.path.join(cfg.SAVE_PATH,
"{}-{}".format(datetime.date.today().strftime('%m%d'),
cfg.MODEL))
if not os.path.exists(output_dir):
os.mkdir(output_dir)
if args.local_rank == 0:
logger = get_logger(os.path.join(output_dir, "train.log"))
else:
logger = None
# initialize distributed parameters
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
if args.local_rank == 0:
logger.info(
'Training on Process %d with %d GPUs.',
args.local_rank, cfg.NUM_GPU)
# fix random seeds
torch.manual_seed(cfg.SEED)
torch.cuda.manual_seed_all(cfg.SEED)
np.random.seed(cfg.SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# generate supernet
model, sta_num, resolution = gen_supernet(
flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM,
flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM,
num_classes=cfg.DATASET.NUM_CLASSES,
drop_rate=cfg.NET.DROPOUT_RATE,
global_pool=cfg.NET.GP,
resunit=cfg.SUPERNET.RESUNIT,
dil_conv=cfg.SUPERNET.DIL_CONV,
slice=cfg.SUPERNET.SLICE,
verbose=cfg.VERBOSE,
logger=logger)
# number of operation choices per searchable block in the supernet
choice_num = len(model.blocks[7])
if args.local_rank == 0:
logger.info('Supernet created, param count: %d', (
sum([m.numel() for m in model.parameters()])))
logger.info('resolution: %d', (resolution))
logger.info('choice number: %d', (choice_num))
# initialize flops look-up table
model_est = FlopsEst(model)
flops_dict, flops_fixed = model_est.flops_dict, model_est.flops_fixed
# optionally resume from a checkpoint
optimizer_state = None
resume_epoch = None
if cfg.AUTO_RESUME:
optimizer_state, resume_epoch = resume_checkpoint(
model, cfg.RESUME_PATH)
# create optimizer and resume from checkpoint
optimizer = create_optimizer_supernet(cfg, model, USE_APEX)
if optimizer_state is not None:
optimizer.load_state_dict(optimizer_state['optimizer'])
model = model.cuda()
# convert model to distributed mode
if cfg.BATCHNORM.SYNC_BN:
try:
if USE_APEX:
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.local_rank == 0:
logger.info('Converted model to use Synchronized BatchNorm.')
except Exception as exception:
logger.info(
'Failed to enable Synchronized BatchNorm: %s. '
'Install Apex or use Torch >= 1.1.', exception)
if USE_APEX:
model = DDP(model, delay_allreduce=True)
else:
if args.local_rank == 0:
logger.info(
"Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
# can use device str in Torch >= 1.1
model = DDP(model, device_ids=[args.local_rank])
# create learning rate scheduler
lr_scheduler, num_epochs = create_supernet_scheduler(cfg, optimizer)
start_epoch = resume_epoch if resume_epoch is not None else 0
if start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
logger.info('Scheduled epochs: %d', num_epochs)
# imagenet train dataset
train_dir = os.path.join(cfg.DATA_DIR, 'train')
if not os.path.exists(train_dir):
logger.info('Training folder does not exist at: %s', train_dir)
sys.exit()
dataset_train = Dataset(train_dir)
loader_train = create_loader(
dataset_train,
input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
batch_size=cfg.DATASET.BATCH_SIZE,
is_training=True,
use_prefetcher=True,
re_prob=cfg.AUGMENTATION.RE_PROB,
re_mode=cfg.AUGMENTATION.RE_MODE,
color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
interpolation='random',
num_workers=cfg.WORKERS,
distributed=True,
collate_fn=None,
crop_pct=DEFAULT_CROP_PCT,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD
)
# imagenet validation dataset
eval_dir = os.path.join(cfg.DATA_DIR, 'val')
if not os.path.isdir(eval_dir):
logger.info('Validation folder does not exist at: %s', eval_dir)
sys.exit()
dataset_eval = Dataset(eval_dir)
loader_eval = create_loader(
dataset_eval,
input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
batch_size=4 * cfg.DATASET.BATCH_SIZE,
is_training=False,
use_prefetcher=True,
num_workers=cfg.WORKERS,
distributed=True,
crop_pct=DEFAULT_CROP_PCT,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
interpolation=cfg.DATASET.INTERPOLATION
)
# whether to use label smoothing
if cfg.AUGMENTATION.SMOOTHING > 0.:
train_loss_fn = LabelSmoothingCrossEntropy(
smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
else:
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = train_loss_fn
mutator = RandomMutator(model)
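# RandomMutator samples a new architecture uniformly at random every time
# reset() is called; the trainer uses these samples as student paths.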
trainer = CreamSupernetTrainer(model, train_loss_fn, validate_loss_fn,
optimizer, num_epochs, loader_train, loader_eval,
mutator=mutator, batch_size=cfg.DATASET.BATCH_SIZE,
log_frequency=cfg.LOG_INTERVAL,
meta_sta_epoch=cfg.SUPERNET.META_STA_EPOCH,
update_iter=cfg.SUPERNET.UPDATE_ITER,
slices=cfg.SUPERNET.SLICE,
pool_size=cfg.SUPERNET.POOL_SIZE,
pick_method=cfg.SUPERNET.PICK_METHOD,
choice_num=choice_num, sta_num=sta_num, acc_gap=cfg.ACC_GAP,
flops_dict=flops_dict, flops_fixed=flops_fixed, local_rank=args.local_rank,
callbacks=[LRSchedulerCallback(lr_scheduler),
ModelCheckpoint(output_dir)])
trainer.train()
if __name__ == '__main__':
main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trainer import CreamSupernetTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import torch
import logging
from copy import deepcopy
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .utils import accuracy, reduce_metrics
logger = logging.getLogger(__name__)
class CreamSupernetTrainer(Trainer):
"""
This trainer trains a supernet and outputs prioritized architectures that can be used for other tasks.
Parameters
----------
model : nn.Module
Model with mutables.
loss : callable
Called with logits and targets. Returns a loss tensor.
val_loss : callable
Called with logits and targets for validation only. Returns a loss tensor.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
train_loader : iterable
Data loader of training. Raises ``StopIteration`` when one epoch is exhausted.
valid_loader : iterable
Data loader of validation. Raises ``StopIteration`` when one epoch is exhausted.
mutator : Mutator
A mutator object that has been initialized with the model.
batch_size : int
Batch size.
log_frequency : int
Number of mini-batches to log metrics.
meta_sta_epoch : int
Epoch at which the meta matching network starts to be used to pick the teacher architecture.
update_iter : int
Interval (in iterations) at which the meta matching network is updated.
slices : int
Batch size of the mini training data used to train the meta matching network.
pool_size : int
Size of the prioritized board.
pick_method : str
How to pick the teacher network ('top1' or 'meta').
choice_num : int
Number of candidate operations per layer in the supernet.
sta_num : tuple of int
Number of layers in each stage of the supernet (the supernet has 5 stages).
acc_gap : int
Maximum accuracy improvement that allows the FLOPs limitation to be ignored.
flops_dict : dict
Dictionary of the FLOPs of each layer's candidate operations in the supernet.
flops_fixed : int
FLOPs of the fixed part of the supernet.
local_rank : int
Index of the current rank.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
"""
def __init__(self, model, loss, val_loss,
optimizer, num_epochs, train_loader, valid_loader,
mutator=None, batch_size=64, log_frequency=None,
meta_sta_epoch=20, update_iter=200, slices=2,
pool_size=10, pick_method='meta', choice_num=6,
sta_num=(4, 4, 4, 4, 4), acc_gap=5,
flops_dict=None, flops_fixed=0, local_rank=0, callbacks=None):
assert torch.cuda.is_available()
super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None,
optimizer, num_epochs, None, None,
batch_size, None, None, log_frequency, callbacks)
self.model = model
self.loss = loss
self.val_loss = val_loss
self.train_loader = train_loader
self.valid_loader = valid_loader
self.log_frequency = log_frequency
self.batch_size = batch_size
self.optimizer = optimizer
self.num_epochs = num_epochs
self.meta_sta_epoch = meta_sta_epoch
self.update_iter = update_iter
self.slices = slices
self.pick_method = pick_method
self.pool_size = pool_size
self.local_rank = local_rank
self.choice_num = choice_num
self.sta_num = sta_num
self.acc_gap = acc_gap
self.flops_dict = flops_dict
self.flops_fixed = flops_fixed
self.current_student_arch = None
self.current_teacher_arch = None
self.main_proc = (local_rank == 0)
self.current_epoch = 0
self.prioritized_board = []
# size of prioritized board
def _board_size(self):
return len(self.prioritized_board)
# select teacher architecture according to the logit difference
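# 'top1' simply takes the best architecture on the board; 'meta' forwards the
# stored data slice of every board entry through the current student, asks the
# meta matching network (model.module.forward_meta) to score the match, picks
# the highest-scoring teacher, and returns sigmoid(-score) as the
# distillation weight.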
def _select_teacher(self):
self._replace_mutator_cand(self.current_student_arch)
if self.pick_method == 'top1':
meta_value, teacher_cand = 0.5, sorted(
self.prioritized_board, reverse=True)[0][3]
elif self.pick_method == 'meta':
meta_value, cand_idx, teacher_cand = -1000000000, -1, None
for now_idx, item in enumerate(self.prioritized_board):
inputx = item[4]
output = torch.nn.functional.softmax(self.model(inputx), dim=1)
weight = self.model.module.forward_meta(output - item[5])
if weight > meta_value:
meta_value = weight
cand_idx = now_idx
teacher_cand = self.prioritized_board[cand_idx][3]
assert teacher_cand is not None
meta_value = torch.sigmoid(-weight)
else:
raise ValueError('Method Not supported')
return meta_value, teacher_cand
# check whether to update prioritized board
def _isUpdateBoard(self, prec1, flops):
if self.current_epoch <= self.meta_sta_epoch:
return False
if len(self.prioritized_board) < self.pool_size:
return True
if prec1 > self.prioritized_board[-1][1] + self.acc_gap:
return True
if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]:
return True
return False
# update prioritized board
def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops):
if self._isUpdateBoard(prec1, flops):
val_prec1 = prec1
training_data = deepcopy(inputs[:self.slices].detach())
if len(self.prioritized_board) == 0:
features = deepcopy(outputs[:self.slices].detach())
else:
features = deepcopy(
teacher_output[:self.slices].detach())
self.prioritized_board.append(
(val_prec1,
prec1,
flops,
self.current_teacher_arch,
training_data,
torch.nn.functional.softmax(
features,
dim=1)))
self.prioritized_board = sorted(
self.prioritized_board, reverse=True)
if len(self.prioritized_board) > self.pool_size:
# the board is already sorted; drop the worst entry
del self.prioritized_board[-1]
# only update student network weights
def _update_student_weights_only(self, grad_1):
for weight, grad_item in zip(
self.model.module.rand_parameters(self.current_student_arch), grad_1):
weight.grad = grad_item
torch.nn.utils.clip_grad_norm_(
self.model.module.rand_parameters(self.current_student_arch), 1)
self.optimizer.step()
for weight, grad_item in zip(
self.model.module.rand_parameters(self.current_student_arch), grad_1):
del weight.grad
# only update meta networks weights
def _update_meta_weights_only(self, teacher_cand, grad_teacher):
for weight, grad_item in zip(self.model.module.rand_parameters(
teacher_cand, self.pick_method == 'meta'), grad_teacher):
weight.grad = grad_item
# clip gradients
torch.nn.utils.clip_grad_norm_(
self.model.module.rand_parameters(
self.current_student_arch, self.pick_method == 'meta'), 1)
self.optimizer.step()
for weight, grad_item in zip(self.model.module.rand_parameters(
teacher_cand, self.pick_method == 'meta'), grad_teacher):
del weight.grad
# simulate sgd updating
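# Returns the tensor w + lr * g without modifying the stored parameter; the
# result keeps its autograd graph so that _calculate_2nd_gradient can
# differentiate the simulated student update back to the meta matching network.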
def _simulate_sgd_update(self, w, g, optimizer):
return g * optimizer.param_groups[-1]['lr'] + w
# split training images into several slices
def _get_minibatch_input(self, input):
slice = self.slices
x = deepcopy(input[:slice].clone().detach())
return x
# calculate 1st gradient of student architectures
def _calculate_1st_gradient(self, kd_loss):
self.optimizer.zero_grad()
grad = torch.autograd.grad(
kd_loss,
self.model.module.rand_parameters(self.current_student_arch),
create_graph=True)
return grad
# calculate 2nd gradient of meta networks
def _calculate_2nd_gradient(self, validation_loss, teacher_cand, students_weight):
self.optimizer.zero_grad()
grad_student_val = torch.autograd.grad(
validation_loss,
self.model.module.rand_parameters(self.current_student_arch),
retain_graph=True)
grad_teacher = torch.autograd.grad(
students_weight[0],
self.model.module.rand_parameters(
teacher_cand,
self.pick_method == 'meta'),
grad_outputs=grad_student_val)
return grad_teacher
# forward training data
def _forward_training(self, x, meta_value):
self._replace_mutator_cand(self.current_student_arch)
output = self.model(x)
with torch.no_grad():
self._replace_mutator_cand(self.current_teacher_arch)
teacher_output = self.model(x)
soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
kd_loss = meta_value * \
self._cross_entropy_loss_with_soft_target(output, soft_label)
return kd_loss
# calculate soft target loss
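# Cross-entropy against the teacher's soft distribution:
# loss = mean_over_batch( -sum_c soft_target_c * log_softmax(pred)_c )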
def _cross_entropy_loss_with_soft_target(self, pred, soft_target):
logsoftmax = torch.nn.LogSoftmax(dim=1)
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
# forward validation data
def _forward_validation(self, input, target):
slice = self.slices
x = input[slice:slice * 2].clone()
self._replace_mutator_cand(self.current_student_arch)
output_2 = self.model(x)
validation_loss = self.loss(output_2, target[slice:slice * 2])
return validation_loss
def _isUpdateMeta(self, batch_idx):
isUpdate = True
isUpdate &= (self.current_epoch > self.meta_sta_epoch)
isUpdate &= (batch_idx > 0)
isUpdate &= (batch_idx % self.update_iter == 0)
isUpdate &= (self._board_size() > 0)
return isUpdate
def _replace_mutator_cand(self, cand):
self.mutator._cache = cand
# update meta matching networks
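# One meta update step: (1) pick a teacher from the prioritized board,
# (2) compute the KD loss of the student on a data slice, (3) take its
# first-order gradient and simulate the resulting student update,
# (4) evaluate the updated student on a validation slice, and
# (5) back-propagate that validation loss through the simulated update
# to obtain gradients for the meta matching network.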
def _run_update(self, input, target, batch_idx):
if self._isUpdateMeta(batch_idx):
x = self._get_minibatch_input(input)
meta_value, teacher_cand = self._select_teacher()
kd_loss = self._forward_training(x, meta_value)
# calculate 1st gradient
grad_1st = self._calculate_1st_gradient(kd_loss)
# simulate updated student weights
students_weight = [
self._simulate_sgd_update(
p, grad_item, self.optimizer) for p, grad_item in zip(
self.model.module.rand_parameters(self.current_student_arch), grad_1st)]
# update student weights
self._update_student_weights_only(grad_1st)
validation_loss = self._forward_validation(input, target)
# calculate 2nd gradient
grad_teacher = self._calculate_2nd_gradient(validation_loss, teacher_cand, students_weight)
# update meta matching networks
self._update_meta_weights_only(teacher_cand, grad_teacher)
# delete internal variants
del grad_teacher, grad_1st, x, validation_loss, kd_loss, students_weight
def _get_cand_flops(self, cand):
flops = 0
for block_id, block in enumerate(cand):
if block == 'LayerChoice1' or block == 'LayerChoice23':
continue
for idx, choice in enumerate(cand[block]):
flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
return flops + self.flops_fixed
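# Train the supernet for one epoch. Once the prioritized board is non-empty
# and the warm-up (meta_sta_epoch) has passed, the student is trained with a
# weighted mix of distillation and ground-truth losses:
# loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2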
def train_one_epoch(self, epoch):
self.current_epoch = epoch
meters = AverageMeterGroup()
self.steps_per_epoch = len(self.train_loader)
for step, (input_data, target) in enumerate(self.train_loader):
self.mutator.reset()
self.current_student_arch = self.mutator._cache
input_data, target = input_data.cuda(), target.cuda()
# calculate flops of current architecture
cand_flops = self._get_cand_flops(self.mutator._cache)
# update meta matching network
self._run_update(input_data, target, step)
if self._board_size() > 0:
# select teacher architecture
meta_value, teacher_cand = self._select_teacher()
self.current_teacher_arch = teacher_cand
# forward supernet
if self._board_size() == 0 or epoch <= self.meta_sta_epoch:
self._replace_mutator_cand(self.current_student_arch)
output = self.model(input_data)
loss = self.loss(output, target)
kd_loss, teacher_output, teacher_cand = None, None, None
else:
self._replace_mutator_cand(self.current_student_arch)
output = self.model(input_data)
gt_loss = self.loss(output, target)
with torch.no_grad():
self._replace_mutator_cand(self.current_teacher_arch)
teacher_output = self.model(input_data).detach()
soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label)
loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2
# update network
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# update metrics
prec1, prec5 = accuracy(output, target, topk=(1, 5))
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
metrics = reduce_metrics(metrics)
meters.update(metrics)
# update prioritized board
self._update_prioritized_board(input_data, teacher_output, output, metrics['prec1'], cand_flops)
if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs,
step + 1, len(self.train_loader), meters)
if self.main_proc and self.num_epochs == epoch + 1:
for idx, i in enumerate(self.prioritized_board):
logger.info("No.%s %s", idx, i[:4])
def validate_one_epoch(self, epoch):
self.model.eval()
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
self.mutator.reset()
logits = self.model(x)
loss = self.val_loss(logits, y)
prec1, prec5 = accuracy(logits, y, topk=(1, 5))
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
metrics = reduce_metrics(metrics)
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.valid_loader), meters)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import torch.distributed as dist
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0)
res.append(correct_k.mul_(1.0 / batch_size))
return res
def reduce_metrics(metrics):
return {k: reduce_tensor(v).item() for k, v in metrics.items()}
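# Average a tensor over all distributed workers (sum via all_reduce, then
# divide by WORLD_SIZE), so the logged metrics reflect the global batch.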
def reduce_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= float(os.environ["WORLD_SIZE"])
return rt