"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "1bba7124ed348eff22367ac8ae6a5cf027b4d7de"
Commit 1011377c authored by qianyj's avatar qianyj
Browse files

the source code of NNI for DCU

parent abc22158
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
from lib.utils.builder_util import *
from lib.utils.search_structure_supernet import *
from lib.models.builders.build_supernet import *
from lib.utils.op_by_layer_dict import flops_op_dict
from timm.models.layers import SelectAdaptivePool2d
from timm.models.layers.activations import hard_sigmoid
class SuperNet(nn.Module):
    """One-shot NAS supernet over an EfficientNet-style search space.

    Layout: stem conv -> searchable stages built by ``SuperNetBuilder``
    (``self.blocks``, each layer holding every candidate op) -> global pool
    -> 1x1 conv head -> linear classifier. ``meta_layer`` is a small
    auxiliary head used during search (see ``forward_meta``).
    """

    def __init__(
            self,
            block_args,
            choices,
            num_classes=1000,
            in_chans=3,
            stem_size=16,
            num_features=1280,
            head_bias=True,
            channel_multiplier=1.0,
            pad_type='',
            act_layer=nn.ReLU,
            drop_rate=0.,
            drop_path_rate=0.,
            slice=4,  # number of class-score vectors fed to meta_layer (shadows builtin `slice`)
            se_kwargs=None,
            norm_layer=nn.BatchNorm2d,
            logger=None,
            norm_kwargs=None,  # NOTE(review): unpacked with ** below, so callers must pass a dict, never None — confirm
            global_pool='avg',
            resunit=False,
            dil_conv=False,
            verbose=False):
        super(SuperNet, self).__init__()
        self.num_classes = num_classes
        self.num_features = num_features
        self.drop_rate = drop_rate
        self._in_chs = in_chans
        self.logger = logger
        # Stem
        stem_size = round_channels(stem_size, channel_multiplier)
        self.conv_stem = create_conv2d(
            self._in_chs, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)
        self._in_chs = stem_size
        # Middle stages (IR/ER/DS Blocks)
        builder = SuperNetBuilder(
            choices,
            channel_multiplier,
            8,    # presumably the channel divisor — TODO confirm against SuperNetBuilder's signature
            None,
            32,
            pad_type,
            act_layer,
            se_kwargs,
            norm_layer,
            norm_kwargs,
            drop_path_rate,
            verbose=verbose,
            resunit=resunit,
            dil_conv=dil_conv,
            logger=self.logger)
        blocks = builder(self._in_chs, block_args)
        self.blocks = nn.Sequential(*blocks)
        self._in_chs = builder.in_chs
        # Head + Pooling
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.conv_head = create_conv2d(
            self._in_chs,
            self.num_features,
            1,
            padding=pad_type,
            bias=head_bias)
        self.act2 = act_layer(inplace=True)
        # Classifier
        self.classifier = nn.Linear(
            self.num_features *
            self.global_pool.feat_mult(),
            self.num_classes)
        # Auxiliary head: maps `slice` concatenated class-score vectors to one scalar.
        self.meta_layer = nn.Linear(self.num_classes * slice, 1)
        efficientnet_init_weights(self)

    def get_classifier(self):
        """Return the final linear classifier module."""
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Swap pooling and classifier for a new class count (classifier
        becomes None when num_classes is falsy)."""
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.num_classes = num_classes
        self.classifier = nn.Linear(
            self.num_features * self.global_pool.feat_mult(),
            num_classes) if self.num_classes else None

    def forward_features(self, x):
        """Stem -> searchable blocks -> pool -> conv head -> activation."""
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        # NOTE: pooling happens BEFORE conv_head here (unlike stock EfficientNet).
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        return x

    def forward(self, x):
        """Features -> flatten -> optional dropout -> class logits."""
        x = self.forward_features(x)
        x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return self.classifier(x)

    def forward_meta(self, features):
        """Score a flattened batch of class outputs with the meta head."""
        return self.meta_layer(features.view(1, -1))

    def rand_parameters(self, architecture, meta=False):
        """Yield the parameters trained for a sampled `architecture`.

        meta=True: only meta_layer parameters. meta=False: all non-block,
        non-meta parameters plus the parameters of the chosen candidate op
        in every searchable block.
        """
        for name, param in self.named_parameters(recurse=True):
            if 'meta' in name and meta:
                yield param
            elif 'blocks' not in name and 'meta' not in name and (not meta):
                yield param
        if not meta:
            # Iterating `architecture` (a dict) yields its keys; each key
            # selects the one-hot choice list for the matching block.
            for choice_blocks, choice_name in zip(self.blocks, architecture):
                choice_sample = architecture[choice_name]
                for block, arch in zip(choice_blocks, choice_sample):
                    if not arch:
                        continue  # candidate op not selected
                    for name, param in block.named_parameters(recurse=True):
                        yield param
class Classifier(nn.Module):
    """Stand-alone classification head: one square linear layer whose input
    and output width both equal ``num_classes``."""

    def __init__(self, num_classes=1000):
        super(Classifier, self).__init__()
        # in_features == out_features == num_classes
        self.classifier = nn.Linear(num_classes, num_classes)

    def forward(self, x):
        logits = self.classifier(x)
        return logits
def gen_supernet(flops_minimum=0, flops_maximum=600, **kwargs):
    """Build a SuperNet whose stage depths fit the given MFLOPs window.

    Returns:
        (model, sta_num, resolution): the supernet, the chosen number of
        layers per searchable stage, and the input resolution.
    Raises:
        ValueError: when no depth/resolution setting satisfies the budget.
    """
    # Per-layer candidate dimensions searched by the supernet.
    choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}

    # act_layer = HardSwish
    act_layer = Swish

    # Maximum-depth base architecture; search_for_layer trims each stage.
    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c16_se0.25'],
        # stage 1, 112x112 in
        ['ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25',
         'ir_r1_k3_s1_e4_c24_se0.25'],
        # stage 2, 56x56 in
        ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s1_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25',
         'ir_r1_k5_s2_e4_c40_se0.25'],
        # stage 3, 28x28 in
        ['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25',
         'ir_r2_k3_s1_e4_c80_se0.25'],
        # stage 4, 14x14in
        ['ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
         'ir_r1_k3_s1_e6_c96_se0.25'],
        # stage 5, 14x14in
        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25',
         'ir_r1_k5_s2_e6_c192_se0.25'],
        # stage 6, 7x7 in
        ['cn_r1_k1_s1_c320_se0.25'],
    ]

    sta_num, arch_def, resolution = search_for_layer(
        flops_op_dict, arch_def, flops_minimum, flops_maximum)
    if sta_num is None or arch_def is None or resolution is None:
        raise ValueError('Invalid FLOPs Settings')

    model = SuperNet(
        block_args=decode_arch_def(arch_def),
        choices=choices,
        num_features=1280,
        stem_size=16,
        norm_kwargs=resolve_bn_args(kwargs),
        act_layer=act_layer,
        se_kwargs=dict(
            act_layer=nn.ReLU,
            gate_fn=hard_sigmoid,
            reduce_mid=True,
            divisor=8),
        **kwargs)
    return model, sta_num, resolution
import re
import math
import torch.nn as nn
from copy import deepcopy
from timm.utils import *
from timm.models.layers.activations import Swish
from timm.models.layers import CondConv2d, get_condconv_initializer
def parse_ksize(ss):
    """Parse a kernel-size token: '3' -> 3, '3.5.7' -> [3, 5, 7]."""
    if not ss.isdigit():
        # dotted form: mixed kernel sizes
        return [int(part) for part in ss.split('.')]
    return int(ss)
def decode_arch_def(
        arch_def,
        depth_multiplier=1.0,
        depth_trunc='ceil',
        experts_multiplier=1):
    """Decode a list of stages (each a list of block strings) into per-stage
    lists of block-arg dicts, expanded by repeats and depth_multiplier.

    Returns:
        list[list[dict]]: one list of block-arg dicts per stage.
    """
    decoded = []
    for block_strings in arch_def:
        assert isinstance(block_strings, list)
        stage_args = []
        stage_repeats = []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            ba, rep = decode_block_str(block_str)
            # Optionally widen CondConv expert count.
            if experts_multiplier > 1 and ba.get('num_experts', 0) > 0:
                ba['num_experts'] *= experts_multiplier
            stage_args.append(ba)
            stage_repeats.append(rep)
        decoded.append(
            scale_stage_depth(
                stage_args, stage_repeats, depth_multiplier, depth_trunc))
    return decoded
def modify_block_args(block_args, kernel_size, exp_ratio):
    """Overwrite a block-arg dict's kernel size (and, for ir/er blocks, its
    expansion ratio) in place; returns the same dict."""
    block_type = block_args['block_type']
    # Each block type keeps its kernel size under a different key.
    kernel_key = {'cn': 'kernel_size', 'er': 'exp_kernel_size'}.get(
        block_type, 'dw_kernel_size')
    block_args[kernel_key] = kernel_size
    if block_type in ('ir', 'er'):
        block_args['exp_ratio'] = exp_ratio
    return block_args
def decode_block_str(block_str):
    """Decode one block definition string into (block_args, num_repeat).

    A definition looks like 'ir_r2_k3_s2_e1_i32_o16_se0.25_noskip': the
    leading token names the block type (ir = InvertedResidual,
    ds = DepthwiseSep, dsa = DepthwiseSep with pointwise act,
    cn = ConvBnAct) and the remaining tokens may appear in any order:

        r  - number of repeats        k  - kernel size
        s  - stride                   e  - expansion ratio
        c  - output channels          se - squeeze/excitation ratio
        n  - activation ('re', 'r6' or 'sw')
        noskip - disable the residual skip path

    Returns:
        tuple: (dict of block arguments, int repeat count)
    """
    assert isinstance(block_str, str)
    tokens = block_str.split('_')
    block_type = tokens[0]  # leading token is the block type
    options = {}
    noskip = False
    for token in tokens[1:]:
        if token == 'noskip':
            noskip = True
        elif token.startswith('n'):
            # activation-function override; unknown codes are silently ignored
            act = {'re': nn.ReLU, 'r6': nn.ReLU6, 'sw': Swish}.get(token[1:])
            if act is not None:
                options['n'] = act
        else:
            # numeric option: letter(s) key followed by digits-and-rest value
            parts = re.split(r'(\d.*)', token)
            if len(parts) >= 2:
                options[parts[0]] = parts[1]

    # None -> the model-level default activation (set at model init) is used
    act_layer = options.get('n')
    exp_kernel_size = parse_ksize(options['a']) if 'a' in options else 1
    pw_kernel_size = parse_ksize(options['p']) if 'p' in options else 1
    # FIXME hack to deal with in_chs issue in TPU def (parsed, unused here)
    fake_in_chs = int(options['fc']) if 'fc' in options else 0
    num_repeat = int(options['r'])

    # Each block type accepts a different argument set.
    if block_type == 'ir':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=parse_ksize(options['k']),
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            noskip=noskip,
        )
        if 'cc' in options:
            block_args['num_experts'] = int(options['cc'])
    elif block_type in ('ds', 'dsa'):
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            pw_act=block_type == 'dsa',
            noskip=block_type == 'dsa' or noskip,
        )
    elif block_type == 'cn':
        block_args = dict(
            block_type=block_type,
            kernel_size=int(options['k']),
            out_chs=int(options['c']),
            stride=int(options['s']),
            act_layer=act_layer,
        )
    else:
        assert False, 'Unknown block type (%s)' % block_type
    return block_args, num_repeat
def scale_stage_depth(
        stack_args,
        repeats,
        depth_multiplier=1.0,
        depth_trunc='ceil'):
    """Scale a stage's total repeat count and expand its block args.

    Keeps EfficientNet-compatible depth scaling while supporting stages
    that define several distinct blocks. 'ceil' truncation (the default)
    grows every stage for any multiplier > 1; 'round' lets short stages
    stay short longer.

    Returns:
        list[dict]: deep-copied block args, one entry per repeated block.
    """
    total = sum(repeats)
    if depth_trunc == 'round':
        scaled_total = max(1, round(total * depth_multiplier))
    else:
        scaled_total = int(math.ceil(total * depth_multiplier))

    # Hand out the scaled budget in reverse so the FIRST block is the least
    # likely to be repeated (repeating it rarely makes sense).
    shares = []
    remaining, remaining_scaled = total, scaled_total
    for rep in repeats[::-1]:
        share = max(1, round((rep / remaining * remaining_scaled)))
        shares.append(share)
        remaining -= rep
        remaining_scaled -= share
    shares.reverse()

    # Expand each block definition by its allocated repeat count.
    expanded = []
    for ba, share in zip(stack_args, shares):
        expanded.extend(deepcopy(ba) for _ in range(share))
    return expanded
def init_weight_goog(m, n='', fix_group_fanout=True, last_bn=None):
    """ Weight initialization as per Tensorflow official implementations.
    Args:
        m (nn.Module): module to init
        n (str): module name
        fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
        last_bn (list[str] | None): names of the final BatchNorm of each
            residual block; their gamma is zero-initialized ("zero-gamma")
    Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
    * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
    * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
    """
    if last_bn is None:
        # BUG FIX: `n in None` raised TypeError when called without last_bn.
        last_bn = []
    if isinstance(m, CondConv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        init_weight_fn = get_condconv_initializer(lambda w: w.data.normal_(
            0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
        init_weight_fn(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Conv2d):
        # He/TF "fan_out" normal init.
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        if n in last_bn:
            # zero-gamma: residual branches start as identity
            m.weight.data.zero_()
            m.bias.data.zero_()
        else:
            m.weight.data.fill_(1.0)
            m.bias.data.zero_()
        # BUG FIX: an unconditional fill_(1.0)/zero_() pair followed the
        # if/else here and silently undid the zero-gamma initialization.
    elif isinstance(m, nn.Linear):
        fan_out = m.weight.size(0)  # fan-out
        fan_in = 0
        if 'routing_fn' in n:
            # CondConv routing layers also count fan-in (TF parity)
            fan_in = m.weight.size(1)
        init_range = 1.0 / math.sqrt(fan_in + fan_out)
        m.weight.data.uniform_(-init_range, init_range)
        m.bias.data.zero_()
def efficientnet_init_weights(
        model: nn.Module,
        init_fn=None,
        zero_gamma=False):
    """Apply `init_fn` (default: init_weight_goog) to every module of `model`.

    When `zero_gamma` is set, collect the name of the last BatchNorm2d of
    each block (detected as a change of the dotted-name prefix between
    consecutive BNs) so init_fn can zero-initialize their gamma.
    """
    last_bn = []
    if zero_gamma:
        prev_n = ''
        for n, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                # A new parent prefix means prev_n was its block's last BN.
                if ''.join(prev_n.split('.')[:-1]) != ''.join(n.split('.')[:-1]):
                    last_bn.append(prev_n)
                prev_n = n
        last_bn.append(prev_n)
    init_fn = init_fn or init_weight_goog
    for n, m in model.named_modules():
        # BUG FIX: init_fn was called twice per module (duplicated line),
        # re-randomizing every weight and doubling initialization cost.
        init_fn(m, n, last_bn=last_bn)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import torch
from ptflops import get_model_complexity_info
class FlopsEst(object):
    """Precompute per-candidate FLOPs/params tables for a supernet.

    Traces the model once with ptflops: the stem, pooling and head costs go
    into fixed totals; every (layer, candidate-op) pair gets an entry in
    ``flops_dict`` / ``params_dict`` (both in millions).

    Args:
        model: supernet exposing .conv_stem, .blocks (sequence of candidate
            lists), .global_pool and .conv_head
        input_shape: NCHW shape used for tracing
        device: 'cpu' keeps the model on CPU, anything else moves it to CUDA
    """

    def __init__(self, model, input_shape=(2, 3, 224, 224), device='cpu'):
        self.block_num = len(model.blocks)
        self.choice_num = len(model.blocks[0])
        self.flops_dict = {}
        self.params_dict = {}
        if device == 'cpu':
            model = model.cpu()
        else:
            model = model.cuda()
        self.params_fixed = 0
        self.flops_fixed = 0
        input = torch.randn(input_shape)
        # stem cost is common to every architecture
        flops, params = get_model_complexity_info(
            model.conv_stem, (3, 224, 224), as_strings=False, print_per_layer_stat=False)
        self.params_fixed += params / 1e6
        self.flops_fixed += flops / 1e6
        input = model.conv_stem(input)
        # one table entry per (layer, candidate op)
        for block_id, block in enumerate(model.blocks):
            self.flops_dict[block_id] = {}
            self.params_dict[block_id] = {}
            for module_id, module in enumerate(block):
                flops, params = get_model_complexity_info(module, tuple(
                    input.shape[1:]), as_strings=False, print_per_layer_stat=False)
                # Flops(M)
                self.flops_dict[block_id][module_id] = flops / 1e6
                # Params(M)
                self.params_dict[block_id][module_id] = params / 1e6
            # advance the probe through one candidate to get the next
            # layer's input shape (all candidates share output shape)
            input = module(input)
        # conv_last
        flops, params = get_model_complexity_info(model.global_pool, tuple(
            input.shape[1:]), as_strings=False, print_per_layer_stat=False)
        self.params_fixed += params / 1e6
        self.flops_fixed += flops / 1e6
        input = model.global_pool(input)
        # globalpool
        flops, params = get_model_complexity_info(model.conv_head, tuple(
            input.shape[1:]), as_strings=False, print_per_layer_stat=False)
        self.params_fixed += params / 1e6
        self.flops_fixed += flops / 1e6

    def get_params(self, arch):
        """Return total params (M) for `arch`, a list of chosen op indices
        per layer; -1 skips the layer."""
        params = 0
        for block_id, block in enumerate(arch):
            if block == -1:
                continue
            params += self.params_dict[block_id][block]
        return params + self.params_fixed

    def get_flops(self, arch):
        """Return total FLOPs (M) for `arch`, a mapping of layer-choice name
        -> one-hot candidate list; the first/last fixed layers are skipped."""
        flops = 0
        for block_id, block in enumerate(arch):
            # BUG FIX: the second comparison used block_id (an int) against
            # the string 'LayerChoice23', so it could never match; compare
            # the key itself for both fixed endpoints.
            if block == 'LayerChoice1' or block == 'LayerChoice23':
                continue
            for idx, choice in enumerate(arch[block]):
                flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
        return flops + self.flops_fixed
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
# This dictionary is generated from calculating each operation of each layer to quickly search for layers.
# flops_op_dict[which_stage][which_operation] =
# (flops_of_operation_with_stride1, flops_of_operation_with_stride2)
# Pre-measured per-operation costs used for fast layer search:
# flops_op_dict[stage][op] = (flops_with_stride1, flops_with_stride2), MFLOPs.
_FLOPS_TABLE = [
    # stage 0
    [(21.828704, 18.820752), (32.669328, 28.16048), (25.039968, 23.637648),
     (37.486224, 35.385824), (29.856864, 30.862992), (44.711568, 46.22384)],
    # stage 1
    [(11.808656, 11.86712), (17.68624, 17.780848), (13.01288, 13.87416),
     (19.492576, 20.791408), (14.819216, 16.88472), (22.20208, 25.307248)],
    # stage 2
    [(8.198, 10.99632), (12.292848, 16.5172), (8.69976, 11.99984),
     (13.045488, 18.02248), (9.4524, 13.50512), (14.174448, 20.2804)],
    # stage 3
    [(12.006112, 15.61632), (18.028752, 23.46096), (13.009632, 16.820544),
     (19.534032, 25.267296), (14.514912, 18.62688), (21.791952, 27.9768)],
    # stage 4
    [(11.307456, 15.292416), (17.007072, 23.1504), (11.608512, 15.894528),
     (17.458656, 24.053568), (12.060096, 16.797696), (18.136032, 25.40832)],
]
flops_op_dict = {
    stage: dict(enumerate(ops)) for stage, ops in enumerate(_FLOPS_TABLE)
}
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
def search_for_layer(flops_op_dict, arch_def, flops_minimum, flops_maximum):
    """Choose per-stage layer counts (and input resolution) so the supernet
    fits inside [flops_minimum, flops_maximum].

    Args:
        flops_op_dict: flops_op_dict[stage][op] = (stride1_cost, stride2_cost)
        arch_def: per-stage lists of block strings (stem + 5 stages + tail)
        flops_minimum / flops_maximum: MFLOPs budget window
    Returns:
        (sta_num, truncated arch_def, resolution), or (None, None, None)
        if no configuration fits the window.
    """
    sta_num = [1, 1, 1, 1, 1]
    # Middle stages are grown first, in two passes with rising caps.
    order = [2, 3, 4, 1, 0, 2, 3, 4, 1, 0]
    limits = [3, 3, 3, 2, 2, 4, 4, 4, 4, 4]
    size_factor = 224 // 32

    base_min_flops = sum([flops_op_dict[i][0][0] for i in range(5)])
    base_max_flops = sum([flops_op_dict[i][5][0] for i in range(5)])

    if base_min_flops > flops_maximum:
        # Too expensive even at minimum depth: shrink the input resolution,
        # rescaling the budget window accordingly.
        while base_min_flops > flops_maximum and size_factor >= 2:
            size_factor -= 1
            flops_minimum *= 7. / size_factor
            flops_maximum *= 7. / size_factor
        if size_factor < 2:
            return None, None, None
    elif base_max_flops < flops_minimum:
        # Too cheap even with the largest ops: add layers until the floor
        # is reached or every visit cap is exhausted.
        ptr = 0
        while base_max_flops < flops_minimum and ptr <= 9:
            if sta_num[order[ptr]] >= limits[ptr]:
                ptr += 1
                continue
            base_max_flops += flops_op_dict[order[ptr]][5][1]
            sta_num[order[ptr]] += 1
        if ptr > 7 and base_max_flops < flops_minimum:
            return None, None, None

    # Greedily keep adding layers while staying under the ceiling.
    ptr = 0
    while ptr <= 9:
        if sta_num[order[ptr]] >= limits[ptr]:
            ptr += 1
            continue
        base_max_flops += flops_op_dict[order[ptr]][5][1]
        if base_max_flops <= flops_maximum:
            sta_num[order[ptr]] += 1
        else:
            break

    # Stem and tail stages keep one block; middle stages are truncated.
    arch_def = [item[:i] for i, item in zip([1] + sta_num + [1], arch_def)]
    # print(arch_def)
    return sta_num, arch_def, size_factor * 32
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import sys
import logging
import torch
import argparse
import torch.nn as nn
from torch import optim as optim
from thop import profile, clever_format
from timm.utils import *
from lib.config import cfg
def get_path_acc(model, path, val_loader, args, val_iters=50):
    """Evaluate one sampled architecture `path` on up to `val_iters`
    validation batches; returns (top1_avg, top5_avg)."""
    top1 = AverageMeter()
    top5 = AverageMeter()
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(val_loader):
            if batch_idx >= val_iters:
                break
            if not args.prefetcher:
                # the prefetcher already delivers CUDA tensors
                input = input.cuda()
                target = target.cuda()
            output = model(input, path)
            if isinstance(output, (tuple, list)):
                output = output[0]
            # augmentation reduction: average TTA copies of each sample
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(
                    0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            torch.cuda.synchronize()
            top1.update(prec1.item(), output.size(0))
            top5.update(prec5.item(), output.size(0))
    return (top1.avg, top5.avg)
def get_logger(file_path):
    """ Make python logger """
    # Root logger echoes to stdout (basicConfig) and to `file_path`.
    fmt = '%(asctime)s | %(message)s'
    datefmt = '%m/%d %I:%M:%S %p'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=fmt, datefmt=datefmt)
    logger = logging.getLogger('')
    file_handler = logging.FileHandler(file_path)
    file_handler.setFormatter(logging.Formatter(fmt, datefmt=datefmt))
    logger.addHandler(file_handler)
    return logger
def add_weight_decay_supernet(model, args, weight_decay=1e-5, skip_list=()):
    """Split parameters into four optimizer groups: (no-)decay x (main|meta).

    1-D params (biases, norm scales) and anything in `skip_list` get zero
    weight decay; meta_layer params use args.meta_lr, everything else
    args.lr. Frozen (requires_grad=False) params are skipped entirely.
    """
    buckets = {
        ('main', True): [], ('main', False): [],
        ('meta', True): [], ('meta', False): [],
    }
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        decays = not (len(param.shape) == 1
                      or name.endswith(".bias")
                      or name in skip_list)
        head = 'meta' if 'meta_layer' in name else 'main'
        buckets[(head, decays)].append(param)
    return [
        {'params': buckets[('main', False)], 'weight_decay': 0., 'lr': args.lr},
        {'params': buckets[('main', True)], 'weight_decay': weight_decay, 'lr': args.lr},
        {'params': buckets[('meta', False)], 'weight_decay': 0., 'lr': args.meta_lr},
        {'params': buckets[('meta', True)], 'weight_decay': 0, 'lr': args.meta_lr},
    ]
def create_optimizer_supernet(args, model, has_apex, filter_bias_and_bn=True):
    """Create the supernet optimizer from the run config.

    Args:
        args: config with .opt, .weight_decay, .lr, .momentum, .opt_eps, ...
        model: the network to optimize
        has_apex: whether NVIDIA Apex is importable (required for 'fused_*')
        filter_bias_and_bn: route params through add_weight_decay_supernet
            so biases/norm params skip weight decay
    Returns:
        torch.optim.Optimizer
    Raises:
        ValueError: if args.opt names an unsupported optimizer.
    """
    opt_lower = args.opt.lower()
    weight_decay = args.weight_decay
    if 'adamw' in opt_lower or 'radam' in opt_lower:
        # Compensate decoupled weight decay for the lr factor applied inside
        # the update (behavior kept from the original code).
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        parameters = add_weight_decay_supernet(model, args, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(
        ), 'APEX and CUDA required for fused optimizers'

    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        optimizer = optim.SGD(
            parameters,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov=True)
    elif opt_lower == 'momentum':
        optimizer = optim.SGD(
            parameters,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov=False)
    elif opt_lower == 'adam':
        optimizer = optim.Adam(
            parameters, weight_decay=weight_decay, eps=args.opt_eps)
    else:
        # BUG FIX: was `assert False and "Invalid optimizer"` (a bare,
        # message-less AssertionError) followed by an unreachable
        # `raise ValueError`; raise the intended descriptive error instead.
        raise ValueError('Invalid optimizer: %s' % opt_lower)
    return optimizer
def convert_lowercase(cfg):
    """Mirror every key of `cfg` under its lowercase spelling, in place.

    Uses setdefault, so existing keys are never overwritten and the
    original mixed-case keys are kept. Returns the same mapping.
    """
    # Snapshot the keys first: setdefault mutates cfg while we iterate.
    for key in list(cfg.keys()):
        cfg.setdefault(key.lower(), cfg.get(key))
    return cfg
def parse_config_args(exp_name):
    """Parse CLI flags (--cfg, --local_rank), merge the YAML file into the
    module-level cfg singleton, and return (args, lowercased cfg)."""
    parser = argparse.ArgumentParser(description=exp_name)
    parser.add_argument(
        '--cfg',
        type=str,
        default='../experiments/workspace/retrain/retrain.yaml',
        help='configuration of cream')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='local_rank')
    args = parser.parse_args()
    # merge_from_file mutates the shared, module-level cfg object
    cfg.merge_from_file(args.cfg)
    return args, convert_lowercase(cfg)
def get_model_flops_params(model, input_size=(1, 3, 224, 224)):
    """Profile a deep copy of `model` with thop and return human-readable
    (macs, params) strings, e.g. ('123.456M', '4.567M')."""
    dummy = torch.randn(input_size)
    # profile a copy so thop's hooks never touch the training model
    macs, params = profile(deepcopy(model), inputs=(dummy,), verbose=False)
    macs, params = clever_format([macs, params], "%.3f")
    return macs, params
def cross_entropy_loss_with_soft_target(pred, soft_target):
    """Cross-entropy of logits against a soft (probability) target.

    Args:
        pred: (batch, num_classes) raw logits.
        soft_target: (batch, num_classes) target probabilities.
    Returns:
        Scalar tensor: batch mean of -sum(soft_target * log_softmax(pred)).
    """
    # BUG FIX: nn.LogSoftmax() with no dim is deprecated and relied on
    # implicit dim inference; the class dim is 1, matching the sum below.
    logsoftmax = nn.LogSoftmax(dim=1)
    return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
def create_supernet_scheduler(cfg, optimizer):
    """Linear-decay LambdaLR over the whole supernet run.

    Total iterations assume the ImageNet train set (1,280,000 images)
    consumed at cfg.NUM_GPU * cfg.DATASET.BATCH_SIZE samples per step.
    The lambda's return value is the FACTOR applied to each base lr:
    (cfg.LR - step / ITERS) while step <= ITERS, then 0.
    """
    iters_total = cfg.EPOCHS * \
        (1280000 / (cfg.NUM_GPU * cfg.DATASET.BATCH_SIZE))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (cfg.LR - step / iters_total) if step <= iters_total else 0,
        last_epoch=-1)
    return scheduler, cfg.EPOCHS
yacs
numpy==1.17
opencv-python==4.0.1.24
torchvision==0.2.1
thop
git+https://github.com/sovrasov/flops-counter.pytorch.git
pillow==6.1.0
torch==1.2
timm==0.1.20
tensorboardx==1.2
tensorboard
future
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import os
import warnings
import datetime
import torch
import numpy as np
import torch.nn as nn
from torchscope import scope
from torch.utils.tensorboard import SummaryWriter
# import timm packages
from timm.optim import create_optimizer
from timm.models import resume_checkpoint
from timm.scheduler import create_scheduler
from timm.data import Dataset, create_loader
from timm.utils import ModelEma, update_summary
from timm.loss import LabelSmoothingCrossEntropy
# import apex as distributed package
try:
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
from apex.parallel import convert_syncbn_model
HAS_APEX = True
except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP
HAS_APEX = False
# import models and training functions
from lib.core.test import validate
from lib.core.retrain import train_epoch
from lib.models.structures.childnet import gen_childnet
from lib.utils.util import parse_config_args, get_logger, get_model_flops_params
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def main():
    """Retrain one of the searched child networks from scratch.

    Flow: parse config -> pick the childnet matching cfg.NET.SELECTION ->
    set up distributed training, ImageNet loaders, optimizer and scheduler
    -> epoch loop of train/validate/checkpoint. Only rank 0 logs, writes
    summaries and profiles the model.
    """
    args, cfg = parse_config_args('nni.cream.childnet')
    # resolve logging
    output_dir = os.path.join(cfg.SAVE_PATH,
                              "{}-{}".format(datetime.date.today().strftime('%m%d'),
                                             cfg.MODEL))
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    # Non-zero ranks keep logger/writer as None; guarded below at each use.
    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, 'retrain.log'))
        writer = SummaryWriter(os.path.join(output_dir, 'runs'))
    else:
        writer, logger = None, None
    # retrain model selection
    # cfg.NET.SELECTION names a searched model (by its MFLOPs budget); each
    # choice fixes the per-stage op indices and the input image size.
    if cfg.NET.SELECTION == 481:
        arch_list = [
            [0], [
                3, 4, 3, 1], [
                3, 2, 3, 0], [
                3, 3, 3, 1], [
                3, 3, 3, 3], [
                3, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 43:
        arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 96
    elif cfg.NET.SELECTION == 14:
        arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
        cfg.DATASET.IMAGE_SIZE = 64
    elif cfg.NET.SELECTION == 112:
        arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 160
    elif cfg.NET.SELECTION == 287:
        arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 604:
        arch_list = [
            [0], [
                3, 3, 2, 3, 3], [
                3, 2, 3, 2, 3], [
                3, 2, 3, 2, 3], [
                3, 3, 2, 2, 3, 3], [
                3, 3, 2, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == -1:
        # explicit architecture supplied in the config file
        arch_list = cfg.NET.INPUT_ARCH
        cfg.DATASET.IMAGE_SIZE = 224
    else:
        raise ValueError("Model Retrain Selection is not Supported!")
    # define childnet architecture from arch_list
    stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
    choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25',
                         'ir_r1_k5_s2_e4_c40_se0.25',
                         'ir_r1_k3_s2_e6_c80_se0.25',
                         'ir_r1_k3_s1_e6_c96_se0.25',
                         'ir_r1_k3_s2_e6_c192_se0.25']
    # Repeat each stage's block string once per entry in arch_list[idx + 1].
    arch_def = [[stem[0]]] + [[choice_block_pool[idx]
                               for repeat_times in range(len(arch_list[idx + 1]))]
                              for idx in range(len(choice_block_pool))] + [[stem[1]]]
    # generate childnet
    model = gen_childnet(
        arch_list,
        arch_def,
        num_classes=cfg.DATASET.NUM_CLASSES,
        drop_rate=cfg.NET.DROPOUT_RATE,
        global_pool=cfg.NET.GP)
    # initialize training parameters
    eval_metric = cfg.EVAL_METRICS
    best_metric, best_epoch, saver = None, None, None
    # initialize distributed parameters
    distributed = cfg.NUM_GPU > 1
    # NOTE(review): the process group is initialized even when NUM_GPU == 1 —
    # confirm the script is always launched via torch.distributed.launch.
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info(
            'Training on Process {} with {} GPUs.'.format(
                args.local_rank, cfg.NUM_GPU))
    # fix random seeds
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # get parameters and FLOPs of model
    if args.local_rank == 0:
        macs, params = get_model_flops_params(model, input_size=(
            1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE))
        logger.info(
            '[Model-{}] Flops: {} Params: {}'.format(cfg.NET.SELECTION, macs, params))
    # create optimizer
    model = model.cuda()
    optimizer = create_optimizer(cfg, model)
    # optionally resume from a checkpoint
    resume_state, resume_epoch = {}, None
    if cfg.AUTO_RESUME:
        resume_state, resume_epoch = resume_checkpoint(model, cfg.RESUME_PATH)
        optimizer.load_state_dict(resume_state['optimizer'])
        del resume_state
    model_ema = None
    if cfg.NET.EMA.USE:
        # exponential moving average of weights; validated separately below
        model_ema = ModelEma(
            model,
            decay=cfg.NET.EMA.DECAY,
            device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
            resume=cfg.RESUME_PATH if cfg.AUTO_RESUME else None)
    if distributed:
        if cfg.BATCHNORM.SYNC_BN:
            try:
                if HAS_APEX:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logger.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                if args.local_rank == 0:
                    logger.error(
                        'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1 with exception {}'.format(e))
        if HAS_APEX:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logger.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
            # can use device str in Torch >= 1.1
            model = DDP(model, device_ids=[args.local_rank])
    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir) and args.local_rank == 0:
        logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)
    loader_train = create_loader(
        dataset_train,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.BATCH_SIZE,
        is_training=True,
        color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
        auto_augment=cfg.AUGMENTATION.AA,
        num_aug_splits=0,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=cfg.WORKERS,
        distributed=distributed,
        collate_fn=None,
        pin_memory=cfg.DATASET.PIN_MEM,
        interpolation='random',
        re_mode=cfg.AUGMENTATION.RE_MODE,
        re_prob=cfg.AUGMENTATION.RE_PROB
    )
    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.exists(eval_dir) and args.local_rank == 0:
        logger.error(
            'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
        is_training=False,
        interpolation=cfg.DATASET.INTERPOLATION,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=cfg.WORKERS,
        distributed=distributed,
        pin_memory=cfg.DATASET.PIN_MEM
    )
    # whether to use label smoothing
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn
    # create learning rate scheduler
    lr_scheduler, num_epochs = create_scheduler(cfg, optimizer)
    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        # fast-forward the schedule to the resumed epoch
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logger.info('Scheduled epochs: {}'.format(num_epochs))
    try:
        best_record, best_ep = 0, 0
        for epoch in range(start_epoch, num_epochs):
            if distributed:
                # reshuffle the distributed sampler every epoch
                loader_train.sampler.set_epoch(epoch)
            train_metrics = train_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                cfg,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                model_ema=model_ema,
                logger=logger,
                writer=writer,
                local_rank=args.local_rank)
            eval_metrics = validate(
                epoch,
                model,
                loader_eval,
                validate_loss_fn,
                cfg,
                logger=logger,
                writer=writer,
                local_rank=args.local_rank)
            if model_ema is not None and not cfg.NET.EMA.FORCE_CPU:
                # report EMA-weight validation metrics when available
                ema_eval_metrics = validate(
                    epoch,
                    model_ema.ema,
                    loader_eval,
                    validate_loss_fn,
                    cfg,
                    log_suffix='_EMA',
                    logger=logger,
                    writer=writer)
                eval_metrics = ema_eval_metrics
            if lr_scheduler is not None:
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
            update_summary(epoch, train_metrics, eval_metrics, os.path.join(
                output_dir, 'summary.csv'), write_header=best_metric is None)
            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model, optimizer, cfg,
                    epoch=epoch, model_ema=model_ema, metric=save_metric)
            # NOTE(review): `saver` is never assigned, so best_metric stays
            # None and only this manual best tracking takes effect — confirm.
            if best_record < eval_metrics[eval_metric]:
                best_record = eval_metrics[eval_metric]
                best_ep = epoch
            if args.local_rank == 0:
                logger.info(
                    '*** Best metric: {0} (epoch {1})'.format(best_record, best_ep))
    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logger.info(
            '*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


if __name__ == '__main__':
    main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import os
import warnings
import datetime
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
# import timm packages
from timm.utils import ModelEma
from timm.models import resume_checkpoint
from timm.data import Dataset, create_loader
# import apex as distributed package
try:
from apex.parallel import convert_syncbn_model
from apex.parallel import DistributedDataParallel as DDP
HAS_APEX = True
except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP
HAS_APEX = False
# import models and training functions
from lib.core.test import validate
from lib.models.structures.childnet import gen_childnet
from lib.utils.util import parse_config_args, get_logger, get_model_flops_params
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def main():
    """Evaluate a searched child network on the ImageNet validation set.

    Selects one of the pre-defined searched architectures via
    ``cfg.NET.SELECTION``, restores its weights from ``cfg.RESUME_PATH``
    and runs a single distributed validation pass — on the EMA weights
    when EMA is enabled, otherwise on the restored model itself.
    """
    args, cfg = parse_config_args('child net testing')

    # resolve logging: <save_path>/<MMDD>-<model_name>
    output_dir = os.path.join(cfg.SAVE_PATH,
                              "{}-{}".format(datetime.date.today().strftime('%m%d'),
                                             cfg.MODEL))
    # exist_ok avoids the check-then-create race when several distributed
    # ranks start simultaneously (plain os.mkdir would raise on the loser)
    os.makedirs(output_dir, exist_ok=True)

    # only rank 0 owns a logger and a tensorboard writer
    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, 'test.log'))
        writer = SummaryWriter(os.path.join(output_dir, 'runs'))
    else:
        writer, logger = None, None

    # retrain model selection: each supported SELECTION maps to a fixed
    # searched architecture and its evaluation image resolution
    if cfg.NET.SELECTION == 481:
        arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1],
                     [3, 3, 3, 3], [3, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 43:
        arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 96
    elif cfg.NET.SELECTION == 14:
        arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
        cfg.DATASET.IMAGE_SIZE = 64
    elif cfg.NET.SELECTION == 112:
        arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 160
    elif cfg.NET.SELECTION == 287:
        arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    elif cfg.NET.SELECTION == 604:
        arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
                     [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
        cfg.DATASET.IMAGE_SIZE = 224
    else:
        raise ValueError("Model Test Selection is not Supported!")

    # define childnet architecture from arch_list
    stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
    choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25',
                         'ir_r1_k5_s2_e4_c40_se0.25',
                         'ir_r1_k3_s2_e6_c80_se0.25',
                         'ir_r1_k3_s1_e6_c96_se0.25',
                         'ir_r1_k5_s2_e6_c192_se0.25']
    # repeat each stage's block definition once per op chosen in that stage
    arch_def = [[stem[0]]] + [[choice_block_pool[idx]
                               for _ in range(len(arch_list[idx + 1]))]
                              for idx in range(len(choice_block_pool))] + [[stem[1]]]

    # generate childnet
    model = gen_childnet(
        arch_list,
        arch_def,
        num_classes=cfg.DATASET.NUM_CLASSES,
        drop_rate=cfg.NET.DROPOUT_RATE,
        global_pool=cfg.NET.GP)
    if args.local_rank == 0:
        macs, params = get_model_flops_params(model, input_size=(
            1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE))
        logger.info(
            '[Model-{}] Flops: {} Params: {}'.format(cfg.NET.SELECTION, macs, params))

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info(
            "Training on Process {} with {} GPUs.".format(
                args.local_rank, cfg.NUM_GPU))

    # resume model from checkpoint
    assert cfg.AUTO_RESUME is True and os.path.exists(cfg.RESUME_PATH)
    _, __ = resume_checkpoint(model, cfg.RESUME_PATH)
    model = model.cuda()

    model_ema = None
    if cfg.NET.EMA.USE:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=cfg.NET.EMA.DECAY,
            device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
            resume=cfg.RESUME_PATH)

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.exists(eval_dir):
        # exit on EVERY rank; previously only rank 0 exited, leaving the
        # other ranks to crash later inside the data loader
        if args.local_rank == 0:
            logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
        is_training=False,
        num_workers=cfg.WORKERS,
        distributed=True,
        pin_memory=cfg.DATASET.PIN_MEM,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD
    )

    # validate EMA weights when available, otherwise the plain model.
    # Previously this crashed with AttributeError on `model_ema.ema` when
    # cfg.NET.EMA.USE was False (model_ema stayed None).
    eval_model = model_ema.ema if model_ema is not None else model
    validate_loss_fn = nn.CrossEntropyLoss().cuda()
    validate(0, eval_model, loader_eval, validate_loss_fn, cfg,
             log_suffix='_EMA', logger=logger,
             writer=writer, local_rank=args.local_rank)


if __name__ == '__main__':
    main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
import os
import sys
import datetime
import torch
import numpy as np
import torch.nn as nn
# import timm packages
from timm.loss import LabelSmoothingCrossEntropy
from timm.data import Dataset, create_loader
from timm.models import resume_checkpoint
# import apex as distributed package
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.parallel import convert_syncbn_model
USE_APEX = True
except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP
USE_APEX = False
# import models and training functions
from lib.utils.flops_table import FlopsEst
from lib.models.structures.supernet import gen_supernet
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from lib.utils.util import parse_config_args, get_logger, \
create_optimizer_supernet, create_supernet_scheduler
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.nas.pytorch.callbacks import ModelCheckpoint
from nni.algorithms.nas.pytorch.cream import CreamSupernetTrainer
from nni.algorithms.nas.pytorch.random import RandomMutator
def main():
    """Train the Cream supernet with NNI's CreamSupernetTrainer.

    Builds the supernet under the configured FLOPs constraints, wraps it
    for (optionally Apex-based) distributed training, and hands it to the
    NNI trainer together with a RandomMutator.
    """
    args, cfg = parse_config_args('nni.cream.supernet')

    # resolve logging: <save_path>/<MMDD>-<model_name>
    output_dir = os.path.join(cfg.SAVE_PATH,
                              "{}-{}".format(datetime.date.today().strftime('%m%d'),
                                             cfg.MODEL))
    # exist_ok avoids the check-then-create race between distributed ranks
    os.makedirs(output_dir, exist_ok=True)

    # only rank 0 owns a logger; it stays None on the other ranks, so every
    # use below must be guarded by the rank check
    if args.local_rank == 0:
        logger = get_logger(os.path.join(output_dir, "train.log"))
    else:
        logger = None

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info(
            'Training on Process %d with %d GPUs.',
            args.local_rank, cfg.NUM_GPU)

    # fix random seeds for reproducibility
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)
    np.random.seed(cfg.SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # generate supernet
    model, sta_num, resolution = gen_supernet(
        flops_minimum=cfg.SUPERNET.FLOPS_MINIMUM,
        flops_maximum=cfg.SUPERNET.FLOPS_MAXIMUM,
        num_classes=cfg.DATASET.NUM_CLASSES,
        drop_rate=cfg.NET.DROPOUT_RATE,
        global_pool=cfg.NET.GP,
        resunit=cfg.SUPERNET.RESUNIT,
        dil_conv=cfg.SUPERNET.DIL_CONV,
        slice=cfg.SUPERNET.SLICE,
        verbose=cfg.VERBOSE,
        logger=logger)

    # number of choice blocks in supernet
    choice_num = len(model.blocks[7])
    if args.local_rank == 0:
        logger.info('Supernet created, param count: %d', (
            sum([m.numel() for m in model.parameters()])))
        logger.info('resolution: %d', (resolution))
        logger.info('choice number: %d', (choice_num))

    # initialize flops look-up table
    model_est = FlopsEst(model)
    flops_dict, flops_fixed = model_est.flops_dict, model_est.flops_fixed

    # optionally resume from a checkpoint
    optimizer_state = None
    resume_epoch = None
    if cfg.AUTO_RESUME:
        optimizer_state, resume_epoch = resume_checkpoint(
            model, cfg.RESUME_PATH)

    # create optimizer and resume its state from the checkpoint
    optimizer = create_optimizer_supernet(cfg, model, USE_APEX)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state['optimizer'])
    model = model.cuda()

    # convert model to distributed mode
    if cfg.BATCHNORM.SYNC_BN:
        try:
            if USE_APEX:
                model = convert_syncbn_model(model)
            else:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logger.info('Converted model to use Synchronized BatchNorm.')
        except Exception as exception:
            # guard the logger: it is None on non-zero ranks, where the
            # unguarded call used to raise AttributeError
            if args.local_rank == 0:
                logger.info(
                    'Failed to enable Synchronized BatchNorm. '
                    'Install Apex or Torch >= 1.1 with Exception %s', exception)
    if USE_APEX:
        model = DDP(model, delay_allreduce=True)
    else:
        if args.local_rank == 0:
            logger.info(
                "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
        # can use device str in Torch >= 1.1
        model = DDP(model, device_ids=[args.local_rank])

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_supernet_scheduler(cfg, optimizer)
    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logger.info('Scheduled epochs: %d', num_epochs)

    # imagenet train dataset
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    if not os.path.exists(train_dir):
        if args.local_rank == 0:
            logger.info('Training folder does not exist at: %s', train_dir)
        sys.exit()
    dataset_train = Dataset(train_dir)
    loader_train = create_loader(
        dataset_train,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=cfg.DATASET.BATCH_SIZE,
        is_training=True,
        use_prefetcher=True,
        re_prob=cfg.AUGMENTATION.RE_PROB,
        re_mode=cfg.AUGMENTATION.RE_MODE,
        color_jitter=cfg.AUGMENTATION.COLOR_JITTER,
        interpolation='random',
        num_workers=cfg.WORKERS,
        distributed=True,
        collate_fn=None,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD
    )

    # imagenet validation dataset
    eval_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.isdir(eval_dir):
        if args.local_rank == 0:
            logger.info('Validation folder does not exist at: %s', eval_dir)
        sys.exit()
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
        batch_size=4 * cfg.DATASET.BATCH_SIZE,
        is_training=False,
        use_prefetcher=True,
        num_workers=cfg.WORKERS,
        distributed=True,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        interpolation=cfg.DATASET.INTERPOLATION
    )

    # whether to use label smoothing (training only; validation always CE)
    if cfg.AUGMENTATION.SMOOTHING > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.AUGMENTATION.SMOOTHING).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    mutator = RandomMutator(model)

    trainer = CreamSupernetTrainer(model, train_loss_fn, validate_loss_fn,
                                   optimizer, num_epochs, loader_train, loader_eval,
                                   mutator=mutator, batch_size=cfg.DATASET.BATCH_SIZE,
                                   log_frequency=cfg.LOG_INTERVAL,
                                   meta_sta_epoch=cfg.SUPERNET.META_STA_EPOCH,
                                   update_iter=cfg.SUPERNET.UPDATE_ITER,
                                   slices=cfg.SUPERNET.SLICE,
                                   pool_size=cfg.SUPERNET.POOL_SIZE,
                                   pick_method=cfg.SUPERNET.PICK_METHOD,
                                   choice_num=choice_num, sta_num=sta_num, acc_gap=cfg.ACC_GAP,
                                   flops_dict=flops_dict, flops_fixed=flops_fixed, local_rank=args.local_rank,
                                   callbacks=[LRSchedulerCallback(lr_scheduler),
                                              ModelCheckpoint(output_dir)])
    trainer.train()


if __name__ == '__main__':
    main()
[Documentation](https://nni.readthedocs.io/en/latest/NAS/PDARTS.html)
[文档](https://nni.readthedocs.io/zh/latest/NAS/PDARTS.html)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import sys
import time
from argparse import ArgumentParser
import torch
import torch.nn as nn
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint
from nni.algorithms.nas.pytorch.pdarts import PdartsTrainer
# prevent the imports below from being reordered by formatters.
if True:
sys.path.append('../../oneshot/darts')
from utils import accuracy
from model import CNN
import datasets
logger = logging.getLogger('nni')
if __name__ == "__main__":
    # command-line interface for the P-DARTS search
    arg_parser = ArgumentParser("pdarts")
    arg_parser.add_argument('--add_layers', action='append', type=int,
                            help='add layers, default: [0, 6, 12]')
    arg_parser.add_argument('--dropped_ops', action='append', type=int,
                            help='drop ops, default: [3, 2, 1]')
    arg_parser.add_argument("--nodes", default=4, type=int)
    arg_parser.add_argument("--init_layers", default=5, type=int)
    arg_parser.add_argument("--channels", default=16, type=int)
    arg_parser.add_argument("--batch-size", default=64, type=int)
    arg_parser.add_argument("--log-frequency", default=1, type=int)
    arg_parser.add_argument("--epochs", default=50, type=int)
    arg_parser.add_argument("--unrolled", default=False, action="store_true")
    args = arg_parser.parse_args()

    # fill in the documented defaults when the repeatable flags were omitted
    if args.add_layers is None:
        args.add_layers = [0, 6, 12]
    if args.dropped_ops is None:
        args.dropped_ops = [3, 2, 1]

    logger.info("loading data")
    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    def model_creator(layers):
        # build a fresh model / criterion / optimizer / scheduler for each
        # P-DARTS stage with the given number of layers
        net = CNN(32, 3, args.channels, 10, layers, n_nodes=args.nodes)
        loss_fn = nn.CrossEntropyLoss()
        sgd = torch.optim.SGD(net.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(sgd, args.epochs, eta_min=0.001)
        return net, loss_fn, sgd, cosine

    logger.info("initializing trainer")
    trainer = PdartsTrainer(model_creator,
                            init_layers=args.init_layers,
                            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                            pdarts_num_layers=args.add_layers,
                            pdarts_num_to_drop=args.dropped_ops,
                            num_epochs=args.epochs,
                            dataset_train=dataset_train,
                            dataset_valid=dataset_valid,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency,
                            unrolled=args.unrolled,
                            callbacks=[ArchitectureCheckpoint("./checkpoints")])
    logger.info("training")
    trainer.train()
# TextNAS: A Neural Architecture Search Space tailored for Text Representation
TextNAS by MSRA. Official Release.
[Paper link](https://arxiv.org/abs/1912.10729)
## Preparation
Prepare the word vectors and the SST dataset, and organize them in the data directory as shown below:
```
textnas
├── data
│ ├── sst
│ │ └── trees
│ │ ├── dev.txt
│ │ ├── test.txt
│ │ └── train.txt
│ └── glove.840B.300d.txt
├── dataloader.py
├── model.py
├── ops.py
├── README.md
├── search.py
└── utils.py
```
The following link might be helpful for finding and downloading the corresponding dataset:
* [GloVe: Global Vectors for Word Representation](https://nlp.stanford.edu/projects/glove/)
* [Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank](https://nlp.stanford.edu/sentiment/)
## Search
```
python search.py
```
After each search epoch, 10 sampled architectures will be tested directly. Their performances are expected to be 40% - 42% after 10 epochs.
By default, 20 sampled architectures will be exported into `checkpoints` directory for next step.
## Retrain
```
sh run_retrain.sh
```
By default, the script will retrain the architecture provided by the author on the SST-2 dataset.
# TextNAS: A Neural Architecture Search Space tailored for Text Representation
TextNAS 由 MSRA 提出,此为正式发布版本。
[论文链接](https://arxiv.org/abs/1912.10729)
## 准备
准备词向量和 SST 数据集,并按如下结构放到 data 目录中:
```
textnas
├── data
│ ├── sst
│ │ └── trees
│ │ ├── dev.txt
│ │ ├── test.txt
│ │ └── train.txt
│ └── glove.840B.300d.txt
├── dataloader.py
├── model.py
├── ops.py
├── README.md
├── search.py
└── utils.py
```
以下链接有助于查找和下载相应的数据集:
* [GloVe: Global Vectors for Word Representation](https://nlp.stanford.edu/projects/glove/)
* [Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank](https://nlp.stanford.edu/sentiment/)
## 搜索
```
python search.py
```
在每个搜索 Epoch 后,会直接测试 10 个采样的结构。 10 个 Epoch 后的性能预计为 40% - 42%。
默认情况下,20 个采样结构会被导出到 `checkpoints` 目录中,以便进行下一步处理。
## 重新训练
```
sh run_retrain.sh
```
默认情况下,脚本会重新训练 SST-2 数据集上作者所提供的网络结构。
{
"LayerChoice1": [
false, false, false, false, false, true, false, false
],
"InputChoice2": [
true
],
"LayerChoice3": [
false, false, false, false, false, false, false, true
],
"InputChoice4": [
false
],
"InputChoice5": [
true, false
],
"LayerChoice6": [
false, false, false, true, false, false, false, false
],
"InputChoice7": [
false, false
],
"InputChoice8": [
false, false, true
],
"LayerChoice9": [
false, false, false, false, false, false, true, false
],
"InputChoice10": [
false, true, true
],
"InputChoice11": [
false, false, true, false
],
"LayerChoice12": [
false, true, false, false, false, false, false, false
],
"InputChoice13": [
false, true, false, false
],
"InputChoice14": [
false, false, false, false, true
],
"LayerChoice15": [
false, true, false, false, false, false, false, false
],
"InputChoice16": [
false, false, true, false, true
],
"InputChoice17": [
false, false, false, false, true
],
"LayerChoice18": [
true, false, false, false, false, false, false, false
],
"InputChoice19": [
false, false, true, true, true, true
],
"InputChoice20": [
true, false, false, false, false
],
"LayerChoice21": [
false, false, false, false, false, false, true, false
],
"InputChoice22": [
false, true, true, false, false, false, false
],
"InputChoice23": [
false, true, false, false, false
],
"LayerChoice24": [
false, false, false, false, false, true, false, false
],
"InputChoice25": [
false, true, false, true, true, false, true, true
],
"InputChoice26": [
false, false, true, false, false
],
"LayerChoice27": [
false, false, false, false, false, true, false, false
],
"InputChoice28": [
false, false, false, false, false, true, false, true, true
],
"InputChoice29": [
true, false, false, false, false
],
"LayerChoice30": [
false, false, false, false, false, false, false, true
],
"InputChoice31": [
true, true, false, false, true, false, false, true, true, false
],
"InputChoice32": [
true, false, false, false, false
],
"LayerChoice33": [
false, false, false, false, true, false, false, false
],
"InputChoice34": [
true, false, false, true, true, true, true, false, false, false, false
],
"InputChoice35": [
false, false, false, true, false
],
"LayerChoice36": [
false, true, false, false, false, false, false, false
],
"InputChoice37": [
true, true, false, true, false, true, false, false, true, false, false, false
],
"InputChoice38": [
false, false, false, true, false
],
"LayerChoice39": [
false, false, true, false, false, false, false, false
],
"InputChoice40": [
true, true, false, false, false, false, true, false, false, true, true, false, true
],
"InputChoice41": [
false, false, false, true, false
],
"LayerChoice42": [
true, false, false, false, false, false, false, false
],
"InputChoice43": [
false, false, true, false, false, false, true, true, true, false, true, true, false, false
],
"InputChoice44": [
false, false, false, false, true
],
"LayerChoice45": [
false, false, false, true, false, false, false, false
],
"InputChoice46": [
true, false, false, false, false, false, true, false, false, false, true, true, false, false, true
],
"InputChoice47": [
false, false, false, true, false
],
"LayerChoice48": [
false, false, true, false, false, false, false, false
],
"InputChoice49": [
false, false, false, false, false, false, false, false, false, true, true, false, true, false, true, false
],
"InputChoice50": [
false, false, false, false, true
],
"LayerChoice51": [
false, false, false, false, true, false, false, false
],
"InputChoice52": [
false, true, true, true, true, false, false, true, false, true, false, false, false, false, true, false, false
],
"InputChoice53": [
false, false, true, false, false
],
"LayerChoice54": [
false, false, false, true, false, false, false, false
],
"InputChoice55": [
false, false, false, false, false, true, false, false, false, false, false, false, false, true, true, true, false, true
],
"InputChoice56": [
false, false, true, false, false
],
"LayerChoice57": [
false, false, false, true, false, false, false, false
],
"InputChoice58": [
false, false, false, true, false, false, false, false, false, false, true, false, false, false, true, false, false, false, false
],
"InputChoice59": [
false, true, false, false, false
],
"LayerChoice60": [
false, false, false, false, false, true, false, false
],
"InputChoice61": [
true, true, false, false, false, false, false, false, false, false, true, true, false, false, true, true, true, true, false, false
],
"InputChoice62": [
true, false, false, false, false
],
"LayerChoice63": [
false, false, false, false, false, false, false, true
],
"InputChoice64": [
false, true, true, true, false, false, false, true, false, true, true, true, true, false, true, false, false, false, false, false, false
],
"InputChoice65": [
false, false, false, false, true
],
"LayerChoice66": [
false, false, false, false, false, false, false, true
],
"InputChoice67": [
false, false, true, true, true, true, false, true, false, true, true, false, false, false, false, true, false, false, false, false, false, true
],
"InputChoice68": [
false, false, false, true, false
],
"LayerChoice69": [
false, false, false, true, false, false, false, false
],
"InputChoice70": [
true, false, false, true, false, false, false, true, false, false, false, false, true, false, false, false, true, false, false, false, false, false, false
]
}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import pickle
from collections import Counter
import numpy as np
import torch
from torch.utils import data
logger = logging.getLogger("nni.textnas")
class PTBTree:
    """A node of a Penn-Treebank-style constituency parse tree.

    Each node carries a ``label`` (constituent tag / sentiment class), an
    optional terminal ``word``, a list of child ``subtrees``, and the
    token ``span`` it covers.  Trees are built from the bracketed PTB
    text format via :meth:`set_by_text`.
    """

    # PTB escapes curly brackets in terminal tokens; map them back.
    WORD_TO_WORD_MAPPING = {
        "{": "-LCB-",
        "}": "-RCB-"
    }
    def __init__(self):
        self.subtrees = []      # child PTBTree nodes
        self.word = None        # terminal token (set on leaves only)
        self.label = ""         # constituent / sentiment label
        self.parent = None      # parent node; None for the root
        self.span = (-1, -1)    # (left, right) token indices covered
        self.word_vector = None  # HOS, store dx1 RNN word vector
        self.prediction = None  # HOS, store Kx1 prediction vector
    def is_leaf(self):
        """Return True when this node has no children (a terminal)."""
        return len(self.subtrees) == 0
    def set_by_text(self, text, pos=0, left=0):
        """Parse the bracketed string *text* into this node (recursively).

        Args:
            text: full bracketed tree string, e.g. ``"(2 (3 good) (1 bad))"``.
            pos: index of this node's opening ``(`` in *text*.
            left: token index where this node's span starts.
        """
        depth = 0
        right = left
        for i in range(pos + 1, len(text)):
            char = text[i]
            # update the depth
            if char == "(":
                depth += 1
                if depth == 1:
                    # a direct child opens here; recurse into it
                    subtree = PTBTree()
                    subtree.parent = self
                    subtree.set_by_text(text, i, right)
                    right = subtree.span[1]
                    self.span = (left, right)
                    self.subtrees.append(subtree)
            elif char == ")":
                depth -= 1
                if len(self.subtrees) == 0:
                    # leaf node: scan back to the preceding space to find
                    # where the terminal word begins
                    pos = i
                    for j in range(i, 0, -1):
                        if text[j] == " ":
                            pos = j
                            break
                    self.word = text[pos + 1:i]
                    self.span = (left, left + 1)
            # we've reached the end of the category that is the root of this subtree
            if depth == 0 and char == " " and self.label == "":
                self.label = text[pos + 1:i]
            # we've reached the end of the scope for this bracket
            if depth < 0:
                break
        # Fix some issues with variation in output, and one error in the treebank
        # for a word with a punctuation POS
        self.standardise_node()
    def standardise_node(self):
        """Normalize escaped bracket tokens via WORD_TO_WORD_MAPPING."""
        if self.word in self.WORD_TO_WORD_MAPPING:
            self.word = self.WORD_TO_WORD_MAPPING[self.word]
    def __repr__(self, single_line=True, depth=0):
        """Render the subtree back into bracketed PTB notation."""
        ans = ""
        if not single_line and depth > 0:
            ans = "\n" + depth * "\t"
        ans += "(" + self.label
        if self.word is not None:
            ans += " " + self.word
        for subtree in self.subtrees:
            if single_line:
                ans += " "
            ans += subtree.__repr__(single_line, depth + 1)
        ans += ")"
        return ans
def read_tree(source):
    """Read the next complete bracketed parse tree from *source*.

    Accumulates non-empty lines until the parentheses balance, then parses
    the joined text into a PTBTree.  Returns None at end of input.
    """
    collected = []
    balance = 0
    while True:
        raw = source.readline()
        if raw == "":
            # end of file before a complete tree
            return None
        stripped = raw.strip()
        if not stripped:
            continue
        collected.append(stripped)
        # track open-minus-close parentheses across lines
        balance += stripped.count("(") - stripped.count(")")
        if balance == 0:
            result = PTBTree()
            result.set_by_text(" ".join(collected))
            return result
def read_trees(source, max_sents=-1):
    """Read parse trees from the file at *source*.

    Stops at end of file, or after *max_sents* trees when it is positive.
    """
    parsed = []
    with open(source) as handle:
        while True:
            next_tree = read_tree(handle)
            if next_tree is None:
                break
            parsed.append(next_tree)
            # honor the cap only when a positive limit was requested
            if 0 < max_sents <= len(parsed):
                break
    return parsed
class SSTDataset(data.Dataset):
    """Torch dataset of ((sentence_ids, mask), label) samples for SST."""

    def __init__(self, sents, mask, labels):
        self.sents = sents
        self.mask = mask
        self.labels = labels

    def __getitem__(self, index):
        inputs = (self.sents[index], self.mask[index])
        return inputs, self.labels[index]

    def __len__(self):
        return len(self.sents)
def sst_get_id_input(content, word_id_dict, max_input_length):
    """Map a space-separated sentence to fixed-length id and mask vectors.

    Tokens beyond *max_input_length* are truncated; unused positions keep
    the <pad> id with a 0 mask, unknown tokens map to <unknown>.
    """
    pad_id = word_id_dict["<pad>"]
    unk_id = word_id_dict["<unknown>"]
    ids = [pad_id] * max_input_length
    valid = [0] * max_input_length
    for pos, token in enumerate(content.split(" ")[:max_input_length]):
        ids[pos] = word_id_dict.get(token, unk_id)
        valid[pos] = 1
    return ids, valid
def sst_get_phrases(trees, sample_ratio=1.0, is_binary=False, only_sentence=False):
    """Extract (text, label) training pairs from parsed SST trees.

    Args:
        trees: parsed PTBTree objects.
        sample_ratio: probability of keeping each pair (1.0 keeps all).
        is_binary: collapse 5-way labels to binary (0-1 -> 0, 3-4 -> 1,
            neutral 2 dropped).
        only_sentence: emit only whole-sentence pairs, not sub-phrases.

    Returns:
        List of (text, int label) pairs.
    """
    all_phrases = []
    for tree in trees:
        if only_sentence:
            sentence = get_sentence_by_tree(tree)
            label = int(tree.label)
            pair = (sentence, label)
            all_phrases.append(pair)
        else:
            # collect every labelled sub-phrase plus the full sentence.
            # NOTE(review): get_phrases_by_tree appears to already append the
            # full-sentence pair itself, so the sentence may be included
            # twice here — confirm whether the duplication is intended.
            phrases = get_phrases_by_tree(tree)
            sentence = get_sentence_by_tree(tree)
            pair = (sentence, int(tree.label))
            all_phrases.append(pair)
            all_phrases += phrases
    if sample_ratio < 1.:
        # shuffle so the sampled subset below is random
        np.random.shuffle(all_phrases)
    result_phrases = []
    for pair in all_phrases:
        if is_binary:
            phrase, label = pair
            if label <= 1:
                pair = (phrase, 0)
            elif label >= 3:
                pair = (phrase, 1)
            else:
                continue  # drop neutral phrases in binary mode
        if sample_ratio == 1.:
            result_phrases.append(pair)
        else:
            rand_portion = np.random.random()
            if rand_portion < sample_ratio:
                result_phrases.append(pair)
    return result_phrases
def get_phrases_by_tree(tree):
    """Collect (text, label) pairs for every subtree of a parse tree.

    Leaves contribute their single word; internal nodes contribute the
    sentence they span, after both children's phrases.
    """
    if tree is None:
        return []
    if tree.is_leaf():
        return [(tree.word, int(tree.label))]
    # binary tree assumed: recurse into both children, then add this span
    collected = get_phrases_by_tree(tree.subtrees[0])
    collected += get_phrases_by_tree(tree.subtrees[1])
    collected.append((get_sentence_by_tree(tree), int(tree.label)))
    return collected
def get_sentence_by_tree(tree):
    """Reconstruct the sentence spanned by *tree* by joining its leaves."""
    if tree is None:
        return ""
    if tree.is_leaf():
        return tree.word
    # binary tree assumed: join the two children's sentences with a space
    left = get_sentence_by_tree(tree.subtrees[0])
    right = get_sentence_by_tree(tree.subtrees[1])
    return (left + " " + right).strip()
def get_word_id_dict(word_num_dict, word_id_dict, min_count):
    """Assign consecutive ids to sufficiently frequent words.

    Words are visited in sorted order for determinism; each new word gets
    the next free id (len of the dict at assignment time).  Words already
    present keep their existing id.

    Args:
        word_num_dict: mapping word -> corpus frequency.
        word_id_dict: existing word -> id mapping, extended in place.
        min_count: minimum frequency for a word to receive an id.

    Returns:
        The (mutated) word_id_dict.
    """
    # iterate the dict directly in sorted order; the original built a
    # redundant intermediate list via `[k for k in sorted(...keys())]`
    for word in sorted(word_num_dict):
        if word_num_dict[word] >= min_count and word not in word_id_dict:
            word_id_dict[word] = len(word_id_dict)
    return word_id_dict
def load_word_num_dict(phrases, word_num_dict):
    """Accumulate token frequencies from (sentence, label) pairs.

    *word_num_dict* must support ``+=`` on missing keys (e.g. a Counter
    or defaultdict); it is mutated in place and also returned.
    """
    for sentence, _ in phrases:
        for token in sentence.split(" "):
            word_num_dict[token.strip()] += 1
    return word_num_dict
def init_trainable_embedding(embedding_path, word_id_dict, embed_dim=300):
    """Build the initial embedding matrix for the vocabulary.

    Rows for words found in the GloVe file are copied from the pretrained
    vectors; all other rows are random uniform in [-0.25, 0.25).  Row 0 is
    the all-zero <pad> vector and row 1 a small random <unknown> vector.

    Returns a float32 array of shape (len(word_id_dict), embed_dim).
    """
    word_embed_model = load_glove_model(embedding_path, embed_dim)
    assert word_embed_model["pool"].shape[1] == embed_dim
    # random init in [-0.25, 0.25) for every row; pretrained vectors
    # overwrite the matching rows below
    embedding = np.random.random([len(word_id_dict), embed_dim]).astype(np.float32) / 2.0 - 0.25
    embedding[0] = np.zeros(embed_dim)  # PAD
    embedding[1] = (np.random.rand(embed_dim) - 0.5) / 2  # UNK
    for word in sorted(word_id_dict.keys()):
        idx = word_id_dict[word]
        if idx == 0 or idx == 1:
            continue  # keep the special PAD / UNK rows
        if word in word_embed_model["mapping"]:
            embedding[idx] = word_embed_model["pool"][word_embed_model["mapping"][word]]
        else:
            embedding[idx] = np.random.rand(embed_dim) / 2.0 - 0.25
    return embedding
def sst_get_trainable_data(phrases, word_id_dict, max_input_length):
    """Convert (phrase, label) pairs into an SSTDataset of padded arrays."""
    texts, labels, mask = [], [], []
    for phrase, label in phrases:
        if not phrase.split():
            # skip empty / whitespace-only phrases
            continue
        ids, valid = sst_get_id_input(phrase, word_id_dict, max_input_length)
        texts.append(ids)
        labels.append(int(label))
        mask.append(valid)  # field_input is mask
    label_arr = np.array(labels, dtype=np.int64)
    text_arr = np.reshape(texts, [-1, max_input_length]).astype(np.int32)
    mask_arr = np.reshape(mask, [-1, max_input_length]).astype(np.int32)
    return SSTDataset(text_arr, mask_arr, label_arr)
def load_glove_model(filename, embed_dim):
    """Load GloVe vectors from a whitespace-separated text file.

    Returns a dict with ``"mapping"`` (word -> row index) and ``"pool"``
    (float32 matrix of shape (vocab, embed_dim)).  The parsed result is
    cached via pickle next to the source file; NOTE the cache is loaded
    with pickle and must only come from a trusted location.
    """
    cache_path = filename + ".cache"
    if os.path.exists(cache_path):
        logger.info("Found cache. Loading...")
        with open(cache_path, "rb") as fp:
            return pickle.load(fp)
    embedding = {"mapping": dict(), "pool": []}
    # explicit encoding: GloVe releases are UTF-8 regardless of locale
    with open(filename, encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.rstrip("\n")
            # rsplit keeps multi-token words (e.g. ". .") intact
            vocab_word, *vec = line.rsplit(" ", maxsplit=embed_dim)
            # fix: validate against embed_dim instead of the hard-coded 300,
            # which silently broke any non-300-dimensional embedding file
            assert len(vec) == embed_dim, "Unexpected line: '%s'" % line
            embedding["pool"].append(np.array(list(map(float, vec)), dtype=np.float32))
            embedding["mapping"][vocab_word] = i
    embedding["pool"] = np.stack(embedding["pool"])
    with open(cache_path, "wb") as fp:
        pickle.dump(embedding, fp)
    return embedding
def read_data_sst(data_path, max_input_length=64, min_count=1, train_with_valid=False,
                  train_ratio=1., valid_ratio=1., is_binary=False, only_sentence=False):
    """Load the SST corpus and build train/valid/test datasets plus embeddings.

    Args:
        data_path: directory containing ``sst/trees/*.txt`` and the GloVe file.
        max_input_length: pad/truncate every sample to this many tokens.
        min_count: minimum corpus frequency for a word to get a vocab id.
        train_with_valid: fold the dev phrases into the training set.
        train_ratio: sampling ratio for training phrases.
        valid_ratio: sampling ratio for dev (and test) phrases.
        is_binary: collapse labels to binary sentiment (neutral dropped).
        only_sentence: use whole sentences only, no sub-phrases.

    Returns:
        (dataset_train, dataset_valid, dataset_test, embedding_tensor)
    """
    word_id_dict = dict()
    word_num_dict = Counter()
    sst_path = os.path.join(data_path, "sst")
    logger.info("Reading SST data...")
    train_file_name = os.path.join(sst_path, "trees", "train.txt")
    valid_file_name = os.path.join(sst_path, "trees", "dev.txt")
    test_file_name = os.path.join(sst_path, "trees", "test.txt")
    train_trees = read_trees(train_file_name)
    train_phrases = sst_get_phrases(train_trees, train_ratio, is_binary, only_sentence)
    logger.info("Finish load train phrases.")
    valid_trees = read_trees(valid_file_name)
    valid_phrases = sst_get_phrases(valid_trees, valid_ratio, is_binary, only_sentence)
    logger.info("Finish load valid phrases.")
    if train_with_valid:
        train_phrases += valid_phrases
    test_trees = read_trees(test_file_name)
    # the test split always uses whole sentences, regardless of only_sentence
    test_phrases = sst_get_phrases(test_trees, valid_ratio, is_binary, only_sentence=True)
    logger.info("Finish load test phrases.")
    # get word_id_dict: ids 0 and 1 are reserved for padding and OOV tokens
    word_id_dict["<pad>"] = 0
    word_id_dict["<unknown>"] = 1
    load_word_num_dict(train_phrases, word_num_dict)
    logger.info("Finish load train words: %d.", len(word_num_dict))
    load_word_num_dict(valid_phrases, word_num_dict)
    load_word_num_dict(test_phrases, word_num_dict)
    logger.info("Finish load valid+test words: %d.", len(word_num_dict))
    word_id_dict = get_word_id_dict(word_num_dict, word_id_dict, min_count)
    logger.info("After trim vocab length: %d.", len(word_id_dict))
    logger.info("Loading embedding...")
    embedding = init_trainable_embedding(os.path.join(data_path, "glove.840B.300d.txt"), word_id_dict)
    logger.info("Finish initialize word embedding.")
    dataset_train = sst_get_trainable_data(train_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d training samples.", len(dataset_train))
    dataset_valid = sst_get_trainable_data(valid_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d validation samples.", len(dataset_valid))
    dataset_test = sst_get_trainable_data(test_phrases, word_id_dict, max_input_length)
    logger.info("Loaded %d test samples.", len(dataset_test))
    return dataset_train, dataset_valid, dataset_test, torch.from_numpy(embedding)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
import torch
import torch.nn as nn
from nni.nas.pytorch import mutables
from ops import ConvBN, LinearCombine, AvgPool, MaxPool, RNN, Attention, BatchNorm
from utils import GlobalMaxPool, GlobalAvgPool
class Layer(mutables.MutableScope):
    """One searchable TextNAS layer.

    The NAS mutator chooses (a) which one of the last ``choose_from_k``
    previous layers feeds this layer, (b) which candidate op to apply
    (convs of several kernel sizes, pooling, RNN, attention), and (c) an
    optional set of skip connections from any previous layer, followed by
    batch norm.
    """
    def __init__(self, key, prev_keys, hidden_units, choose_from_k, cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask):
        super(Layer, self).__init__(key)

        def conv_shortcut(kernel_size):
            # conv + batch-norm candidate op with the given kernel size
            return ConvBN(kernel_size, hidden_units, hidden_units, cnn_keep_prob, False, True)
        self.n_candidates = len(prev_keys)
        if self.n_candidates:
            # pick exactly one of the last choose_from_k previous layers
            self.prec = mutables.InputChoice(choose_from=prev_keys[-choose_from_k:], n_chosen=1)
        else:
            # first layer, skip input choice
            self.prec = None
        # the candidate operations the search selects among
        self.op = mutables.LayerChoice([
            conv_shortcut(1),
            conv_shortcut(3),
            conv_shortcut(5),
            conv_shortcut(7),
            AvgPool(3, False, True),
            MaxPool(3, False, True),
            RNN(hidden_units, lstm_keep_prob),
            Attention(hidden_units, 4, att_keep_prob, att_mask)
        ])
        if self.n_candidates:
            # optional residual connections from any previous layer
            self.skipconnect = mutables.InputChoice(choose_from=prev_keys)
        else:
            self.skipconnect = None
        self.bn = BatchNorm(hidden_units, False, True)
    def forward(self, last_layer, prev_layers, mask):
        """Apply the chosen input / op / skip-connect selection.

        Args:
            last_layer: output of the immediately preceding layer (used
                directly when this is layer 0 and prev_layers is empty).
            prev_layers: outputs of all previous layers, oldest first.
            mask: validity mask forwarded to the candidate ops.
        """
        # pass an extra last_layer to deal with layer 0 (prev_layers is empty)
        if self.prec is None:
            prec = last_layer
        else:
            prec = self.prec(prev_layers[-self.prec.n_candidates:])  # skip first
        out = self.op(prec, mask)
        if self.skipconnect is not None:
            # InputChoice returns None when no skip input is selected
            connection = self.skipconnect(prev_layers[-self.skipconnect.n_candidates:])
            if connection is not None:
                out += connection
        out = self.bn(out, mask)
        return out
class Model(nn.Module):
    """TextNAS supernet: a stack of searchable Layers over word embeddings.

    Input is a (sentence_ids, mask) pair; output is the class logits.
    """
    def __init__(self, embedding, hidden_units=256, num_layers=24, num_classes=5, choose_from_k=5,
                 lstm_keep_prob=0.5, cnn_keep_prob=0.5, att_keep_prob=0.5, att_mask=True,
                 embed_keep_prob=0.5, final_output_keep_prob=1.0, global_pool="avg"):
        super(Model, self).__init__()
        # pretrained embedding matrix, fine-tuned during training (freeze=False)
        self.embedding = nn.Embedding.from_pretrained(embedding, freeze=False)
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.num_classes = num_classes
        # kernel-size-1 ConvBN projecting embeddings to hidden_units channels
        self.init_conv = ConvBN(1, self.embedding.embedding_dim, hidden_units, cnn_keep_prob, False, True)
        self.layers = nn.ModuleList()
        candidate_keys_pool = []
        for layer_id in range(self.num_layers):
            k = "layer_{}".format(layer_id)
            # each layer may choose its inputs among all previously built layers
            self.layers.append(Layer(k, candidate_keys_pool, hidden_units, choose_from_k,
                                     cnn_keep_prob, lstm_keep_prob, att_keep_prob, att_mask))
            candidate_keys_pool.append(k)
        # learned weighted combination over all layer outputs
        self.linear_combine = LinearCombine(self.num_layers)
        self.linear_out = nn.Linear(self.hidden_units, self.num_classes)
        self.embed_dropout = nn.Dropout(p=1 - embed_keep_prob)
        self.output_dropout = nn.Dropout(p=1 - final_output_keep_prob)
        assert global_pool in ["max", "avg"]
        if global_pool == "max":
            self.global_pool = GlobalMaxPool()
        elif global_pool == "avg":
            self.global_pool = GlobalAvgPool()
    def forward(self, inputs):
        """Compute class logits; *inputs* is a (sentence_ids, mask) pair."""
        sent_ids, mask = inputs
        seq = self.embedding(sent_ids.long())
        seq = self.embed_dropout(seq)
        seq = torch.transpose(seq, 1, 2)  # from (N, L, C) -> (N, C, L)
        x = self.init_conv(seq, mask)
        prev_layers = []
        for layer in self.layers:
            x = layer(x, prev_layers, mask)
            prev_layers.append(x)
        # combine every intermediate output, then pool over sequence length
        x = self.linear_combine(torch.stack(prev_layers))
        x = self.global_pool(x, mask)
        x = self.output_dropout(x)
        x = self.linear_out(x)
        return x
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn.functional as F
from torch import nn
from utils import get_length, INF
class Mask(nn.Module):
    """Zero out padded positions of a channel-first sequence.

    Inputs:
        seq:  tensor of shape (N, C, L).
        mask: tensor of shape (N, L), 1 at valid positions and 0 at padding.
    Returns a tensor shaped like ``seq`` with every padded position zeroed.
    """

    def forward(self, seq, mask):
        # Broadcast the (N, L) mask up to (N, C, L) so it lines up with seq.
        expanded = mask.unsqueeze(2).repeat(1, 1, seq.size(1)).transpose(1, 2)
        return seq.where(expanded == 1, torch.zeros_like(seq))
class BatchNorm(nn.Module):
    """1-D batch normalisation with optional masking before and/or after.

    ``decay`` is the running-statistics decay; torch's ``momentum`` is the
    update weight, hence ``momentum = 1 - decay``.
    """

    def __init__(self, num_features, pre_mask, post_mask, eps=1e-5, decay=0.9, affine=True):
        super(BatchNorm, self).__init__()
        self.mask_opt = Mask()
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.bn = nn.BatchNorm1d(num_features, eps=eps, momentum=1.0 - decay, affine=affine)

    def forward(self, seq, mask):
        out = self.mask_opt(seq, mask) if self.pre_mask else seq
        out = self.bn(out)
        return self.mask_opt(out, mask) if self.post_mask else out
class ConvBN(nn.Module):
    """Conv1d -> optional masking -> optional BatchNorm -> optional ReLU -> dropout.

    Stride-1 convolution with symmetric padding keeps the sequence length
    unchanged for odd kernel sizes.
    """

    def __init__(self, kernal_size, in_channels, out_channels, cnn_keep_prob,
                 pre_mask, post_mask, with_bn=True, with_relu=True):
        super(ConvBN, self).__init__()
        self.mask_opt = Mask()
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.with_bn = with_bn
        self.with_relu = with_relu
        self.conv = nn.Conv1d(in_channels, out_channels, kernal_size, 1,
                              bias=True, padding=(kernal_size - 1) // 2)
        self.dropout = nn.Dropout(p=(1 - cnn_keep_prob))
        if with_bn:
            # The BN pre-masks only when this module does not post-mask itself.
            self.bn = BatchNorm(out_channels, not post_mask, True)
        if with_relu:
            self.relu = nn.ReLU()

    def forward(self, seq, mask):
        out = self.mask_opt(seq, mask) if self.pre_mask else seq
        out = self.conv(out)
        if self.post_mask:
            out = self.mask_opt(out, mask)
        if self.with_bn:
            out = self.bn(out, mask)
        if self.with_relu:
            out = self.relu(out)
        return self.dropout(out)
class AvgPool(nn.Module):
    """Length-preserving 1-D average pooling with optional pre/post masking."""

    def __init__(self, kernal_size, pre_mask, post_mask):
        super(AvgPool, self).__init__()
        # Stride 1 plus symmetric padding keeps the output length equal to L.
        self.avg_pool = nn.AvgPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.mask_opt = Mask()

    def forward(self, seq, mask):
        out = self.mask_opt(seq, mask) if self.pre_mask else seq
        out = self.avg_pool(out)
        return self.mask_opt(out, mask) if self.post_mask else out
class MaxPool(nn.Module):
    """Length-preserving 1-D max pooling with optional pre/post masking."""

    def __init__(self, kernal_size, pre_mask, post_mask):
        super(MaxPool, self).__init__()
        # Stride 1 plus symmetric padding keeps the output length equal to L.
        self.max_pool = nn.MaxPool1d(kernal_size, 1, padding=(kernal_size - 1) // 2)
        self.pre_mask = pre_mask
        self.post_mask = post_mask
        self.mask_opt = Mask()

    def forward(self, seq, mask):
        out = self.mask_opt(seq, mask) if self.pre_mask else seq
        out = self.max_pool(out)
        return self.mask_opt(out, mask) if self.post_mask else out
class Attention(nn.Module):
    """Multi-head self-attention over a channel-first sequence.

    Takes ``seq`` of shape (N, C, L) and a padding ``mask`` of shape (N, L),
    and returns a batch-normalised tensor of the same shape.

    Args:
        num_units: channel count C. The head split below assumes it is
            divisible by ``num_heads`` — not checked here; TODO confirm callers.
        num_heads: number of attention heads.
        keep_prob: keep probability for dropout on the attention weights.
        is_mask: whether the output BatchNorm re-masks padded positions.
    """
    def __init__(self, num_units, num_heads, keep_prob, is_mask):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.keep_prob = keep_prob
        self.linear_q = nn.Linear(num_units, num_units)
        self.linear_k = nn.Linear(num_units, num_units)
        self.linear_v = nn.Linear(num_units, num_units)
        self.bn = BatchNorm(num_units, True, is_mask)
        self.dropout = nn.Dropout(p=1 - self.keep_prob)
    def forward(self, seq, mask):
        in_c = seq.size()[1]  # channel count C, used for the per-head split
        seq = torch.transpose(seq, 1, 2)  # (N, L, C)
        queries = seq
        keys = seq
        num_heads = self.num_heads
        # T_q = T_k = L (self-attention: queries and keys are the same sequence)
        Q = F.relu(self.linear_q(seq))  # (N, T_q, C)
        K = F.relu(self.linear_k(seq))  # (N, T_k, C)
        V = F.relu(self.linear_v(seq))  # (N, T_k, C)
        # Split channels into heads and stack the heads along the batch axis.
        Q_ = torch.cat(torch.split(Q, in_c // num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.split(K, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)
        V_ = torch.cat(torch.split(V, in_c // num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)
        # Dot-product scores between every query and key position.
        outputs = torch.matmul(Q_, K_.transpose(1, 2))  # (h*N, T_q, T_k)
        # Scale by sqrt of the per-head dimension.
        outputs = outputs / (K_.size()[-1] ** 0.5)
        # Key masking: padded key positions get -INF so softmax gives them ~0 weight.
        key_masks = mask.repeat(num_heads, 1)  # (h*N, T_k)
        key_masks = torch.unsqueeze(key_masks, 1)  # (h*N, 1, T_k)
        key_masks = key_masks.repeat(1, queries.size()[1], 1)  # (h*N, T_q, T_k)
        paddings = torch.ones_like(outputs) * (-INF)  # extremely small value
        outputs = torch.where(torch.eq(key_masks, 0), paddings, outputs)
        # Query masking: zero out the attention rows of padded query positions.
        query_masks = mask.repeat(num_heads, 1)  # (h*N, T_q)
        query_masks = torch.unsqueeze(query_masks, -1)  # (h*N, T_q, 1)
        query_masks = query_masks.repeat(1, 1, keys.size()[1]).float()  # (h*N, T_q, T_k)
        att_scores = F.softmax(outputs, dim=-1) * query_masks  # (h*N, T_q, T_k)
        att_scores = self.dropout(att_scores)
        # Weighted sum of values by the attention weights.
        x_outputs = torch.matmul(att_scores, V_)  # (h*N, T_q, C/h)
        # Restore shape: undo the head split by re-concatenating along channels.
        x_outputs = torch.cat(
            torch.split(x_outputs, x_outputs.size()[0] // num_heads, dim=0),
            dim=2)  # (N, T_q, C)
        x = torch.transpose(x_outputs, 1, 2)  # (N, C, L)
        x = self.bn(x, mask)
        return x
class RNN(nn.Module):
    """Bidirectional GRU whose two directions are summed channel-wise.

    Consumes ``seq`` of shape (N, C, L) with a padding ``mask`` of shape
    (N, L) and returns a tensor of the same (N, C, L) shape.
    """

    def __init__(self, hidden_size, output_keep_prob):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.bid_rnn = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.output_keep_prob = output_keep_prob
        self.out_dropout = nn.Dropout(p=(1 - self.output_keep_prob))

    def forward(self, seq, mask):
        total_len = seq.size(2)
        lengths = get_length(mask)
        inputs = seq.transpose(1, 2)  # to (N, L, C) for the batch_first GRU
        # Pack so the GRU skips padded steps; lengths need not be pre-sorted.
        packed = nn.utils.rnn.pack_padded_sequence(inputs, lengths, batch_first=True,
                                                   enforce_sorted=False)
        rnn_out, _ = self.bid_rnn(packed)
        rnn_out, _ = nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True,
                                                      total_length=total_len)
        # Merge forward/backward directions by summation: (N, L, 2, C) -> (N, L, C).
        rnn_out = rnn_out.view(-1, total_len, 2, self.hidden_size).sum(2)
        rnn_out = self.out_dropout(rnn_out)
        return rnn_out.transpose(1, 2)  # back to (N, C, L)
class LinearCombine(nn.Module):
    """Softmax-weighted sum of stacked layer outputs.

    Input is a tensor of shape (layers_num, N, C, L); output is (N, C, L).
    Weights are one learnable scalar per layer, normalised with softmax.
    ``input_aware=True`` is reserved and raises ``NotImplementedError``.
    """

    def __init__(self, layers_num, trainable=True, input_aware=False, word_level=False):
        super(LinearCombine, self).__init__()
        self.input_aware = input_aware
        self.word_level = word_level
        if input_aware:
            raise NotImplementedError("Input aware is not supported.")
        # Uniform initial weights; shape (layers, 1, 1, 1) broadcasts over (layers, N, C, L).
        init = torch.full((layers_num, 1, 1, 1), 1.0 / layers_num)
        self.w = nn.Parameter(init, requires_grad=trainable)

    def forward(self, seq):
        weights = F.softmax(self.w, dim=0)
        return (seq * weights).sum(dim=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment