"docs/en_US/SklearnExamples.md" did not exist on "d76d379b1c8e00ee55ea5aa0405e392ac228a214"
Unverified commit affb2118 authored by QuanluZhang, committed by GitHub

support proxylessnas with NNI NAS APIs (#1863)

parent fdcd877f
...@@ -19,6 +19,7 @@ NNI supports below NAS algorithms now and is adding more. User can reproduce an ...
| [P-DARTS](PDARTS.md) | [Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation](https://arxiv.org/abs/1904.12760) is based on DARTS. It introduces an efficient algorithm which allows the depth of searched architectures to grow gradually during the training procedure. |
| [SPOS](SPOS.md) | [Single Path One-Shot Neural Architecture Search with Uniform Sampling](https://arxiv.org/abs/1904.00420) constructs a simplified supernet trained with a uniform path sampling method, and applies an evolutionary algorithm to efficiently search for the best-performing architectures. |
| [CDARTS](CDARTS.md) | [Cyclic Differentiable Architecture Search](https://arxiv.org/abs/****) builds a cyclic feedback mechanism between the search and evaluation networks. It introduces a cyclic differentiable architecture search framework which integrates the two networks into a unified architecture.|
| [ProxylessNAS](Proxylessnas.md) | [ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware](https://arxiv.org/abs/1812.00332) removes the proxy and directly learns architectures for large-scale target tasks and target hardware platforms. |
One-shot algorithms run **standalone without nnictl**. Only the PyTorch version has been implemented; TensorFlow 2.x will be supported in a future release.
...
# ProxylessNAS on NNI
## Introduction
The paper [ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware](https://arxiv.org/pdf/1812.00332.pdf) removes the proxy task: it directly learns architectures for large-scale target tasks and target hardware platforms. The authors address the high memory consumption of differentiable NAS and reduce the computational cost to the same level as regular training, while still allowing a large candidate set. Please refer to the paper for details.
## Usage
To use the ProxylessNAS training/searching approach, users need to specify the search space in their model using the [NNI NAS interface](NasGuide.md), e.g., `LayerChoice` and `InputChoice`. After defining and instantiating the model, the remaining work can be left to ProxylessNasTrainer by instantiating the trainer and passing the model to it.
```python
trainer = ProxylessNasTrainer(model,
model_optim=optimizer,
train_loader=data_provider.train,
valid_loader=data_provider.valid,
device=device,
warmup=True,
ckpt_path=args.checkpoint_path,
arch_path=args.arch_path)
trainer.train()
trainer.export(args.arch_path)
```
The complete example code can be found [here](https://github.com/microsoft/nni/tree/master/examples/nas/proxylessnas).
**Input arguments of ProxylessNasTrainer** (a fuller instantiation sketch follows this list)
* **model** (*PyTorch model, required*) - The model that users want to tune/search. It contains mutables that specify the search space.
* **model_optim** (*PyTorch optimizer, required*) - The optimizer used to train the model.
* **device** (*device, required*) - The device(s) on which to train/search. The trainer applies data parallelism to the model on behalf of users.
* **train_loader** (*PyTorch data loader, required*) - The data loader for training set.
* **valid_loader** (*PyTorch data loader, required*) - The data loader for validation set.
* **label_smoothing** (*float, optional, default = 0.1*) - The degree of label smoothing.
* **n_epochs** (*int, optional, default = 120*) - The number of epochs to train/search.
* **init_lr** (*float, optional, default = 0.025*) - The initial learning rate for training the model.
* **binary_mode** (*'two', 'full', or 'full_v2', optional, default = 'full_v2'*) - The forward/backward mode for the binary weights in the mutator. 'full' forwards all candidate ops, 'two' forwards only two sampled ops, and 'full_v2' recomputes the inactive ops during backward.
* **arch_init_type** (*'normal' or 'uniform', optional, default = 'normal'*) - The way to initialize architecture parameters.
* **arch_init_ratio** (*float, optional, default = 1e-3*) - The ratio used to initialize architecture parameters.
* **arch_optim_lr** (*float, optional, default = 1e-3*) - The learning rate of the architecture parameters optimizer.
* **arch_weight_decay** (*float, optional, default = 0*) - Weight decay of the architecture parameters optimizer.
* **grad_update_arch_param_every** (*int, optional, default = 5*) - Update architecture weights every this many minibatches.
* **grad_update_steps** (*int, optional, default = 1*) - The number of steps to train architecture weights during each update.
* **warmup** (*bool, optional, default = True*) - Whether to do warmup.
* **warmup_epochs** (*int, optional, default = 25*) - The number of warmup epochs.
* **arch_valid_frequency** (*int, optional, default = 1*) - The frequency of printing validation result.
* **load_ckpt** (*bool, optional, default = False*) - Whether to load checkpoint.
* **ckpt_path** (*str, optional, default = None*) - Checkpoint path. If load_ckpt is True, ckpt_path cannot be None.
* **arch_path** (*str, optional, default = None*) - The path to store the chosen architecture.
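For reference, below is a minimal sketch that spells out several of the optional arguments with their documented default values. It is illustrative only; keep the defaults unless your task requires otherwise.
```python
trainer = ProxylessNasTrainer(model,
                              model_optim=optimizer,
                              train_loader=data_provider.train,
                              valid_loader=data_provider.valid,
                              device=device,
                              label_smoothing=0.1,         # default
                              n_epochs=120,                # default
                              init_lr=0.025,               # default
                              binary_mode='full_v2',       # default
                              arch_optim_lr=1e-3,          # default
                              grad_update_arch_param_every=5,
                              warmup=True,
                              warmup_epochs=25,
                              load_ckpt=False,
                              ckpt_path=args.checkpoint_path,
                              arch_path=args.arch_path)
```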
## Implementation
The implementation on NNI is based on the [official implementation](https://github.com/mit-han-lab/ProxylessNAS). The official implementation supports two training approaches, gradient descent and RL-based, and supports different target hardware, including 'mobile', 'cpu', 'gpu8', and 'flops'. Our current implementation on NNI supports the gradient descent training approach, but does not yet support different hardware targets; complete support is ongoing.
Below we describe the implementation details. Like other one-shot NAS algorithms on NNI, ProxylessNAS is composed of two parts: *search space* and *training approach*. So that users can flexibly define their own search space while using the built-in ProxylessNAS training approach, we put the specified search space in the [example code](https://github.com/microsoft/nni/tree/master/examples/nas/proxylessnas) using the [NNI NAS interface](NasGuide.md), and put the training approach in the [SDK](https://github.com/microsoft/nni/tree/master/src/sdk/pynni/nni/nas/pytorch/proxylessnas).
![](../../img/proxylessnas.png)
The ProxylessNAS training approach is composed of ProxylessNasMutator and ProxylessNasTrainer. ProxylessNasMutator instantiates a MixedOp for each mutable (i.e., LayerChoice) and manages the architecture weights inside each MixedOp. **For DataParallel**, the architecture weights must be included in the user model; specifically, the ProxylessNAS implementation adds the MixedOp to the corresponding mutable (i.e., LayerChoice) as a member variable. The mutator also exposes two member functions, `arch_requires_grad` and `arch_disable_grad`, which the trainer uses to enable and disable training of the architecture weights.
ProxylessNasMutator also implements the forward logic of the mutables (i.e., LayerChoice); the standalone sketch below illustrates the registration idea.
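The following is a minimal, self-contained sketch of this registration idea. It does **not** use the NNI API; `ToyMixedOp` and `ToyLayerChoice` are hypothetical stand-ins that only show why registering the architecture weights as a member of the choice module makes them visible to `torch.nn.DataParallel`.
```python
import torch
import torch.nn as nn

class ToyMixedOp(nn.Module):
    """Holds one architecture weight (alpha) per candidate op."""
    def __init__(self, n_candidates):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(n_candidates))

class ToyLayerChoice(nn.Module):
    """Registers the MixedOp as a member, so alpha lives inside the model."""
    def __init__(self, candidates):
        super().__init__()
        self.candidates = nn.ModuleList(candidates)
        self.registered_module = ToyMixedOp(len(candidates))

    def forward(self, x):
        # weighted sum over candidates, loosely analogous to the 'full' mode
        weights = torch.softmax(self.registered_module.alpha, dim=0)
        return sum(w * op(x) for w, op in zip(weights, self.candidates))

choice = ToyLayerChoice([nn.Conv2d(3, 8, 3, padding=1),
                         nn.Conv2d(3, 8, 5, padding=2)])
x = torch.randn(2, 3, 32, 32)
print(choice(x).shape)  # torch.Size([2, 8, 32, 32])
# alpha is a regular model parameter, so DataParallel replicates it
print('registered_module.alpha' in dict(choice.named_parameters()))  # True
```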
## Reproduce Results
Ongoing...
...@@ -24,4 +24,5 @@ For details, please refer to the following tutorials: ...
P-DARTS <NAS/PDARTS>
SPOS <NAS/SPOS>
CDARTS <NAS/CDARTS>
ProxylessNAS <NAS/Proxylessnas>
API Reference <NAS/NasReference>
import os
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
def get_split_list(in_dim, child_num):
in_dim_list = [in_dim // child_num] * child_num
for _i in range(in_dim % child_num):
in_dim_list[_i] += 1
return in_dim_list
class DataProvider:
VALID_SEED = 0 # random seed for the validation set
@staticmethod
def name():
""" Return name of the dataset """
raise NotImplementedError
@property
def data_shape(self):
""" Return shape as python list of one data entry """
raise NotImplementedError
@property
def n_classes(self):
""" Return `int` of num classes """
raise NotImplementedError
@property
def save_path(self):
""" local path to save the data """
raise NotImplementedError
@property
def data_url(self):
""" link to download the data """
raise NotImplementedError
@staticmethod
def random_sample_valid_set(train_labels, valid_size, n_classes):
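# Draw a class-balanced validation split: shuffle indices with a fixed seed,
# then reserve roughly valid_size / n_classes samples per class for validation
# and keep the remaining indices for training.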
train_size = len(train_labels)
assert train_size > valid_size
g = torch.Generator()
g.manual_seed(DataProvider.VALID_SEED) # set random seed before sampling validation set
rand_indexes = torch.randperm(train_size, generator=g).tolist()
train_indexes, valid_indexes = [], []
per_class_remain = get_split_list(valid_size, n_classes)
for idx in rand_indexes:
label = train_labels[idx]
if isinstance(label, float):
label = int(label)
elif isinstance(label, np.ndarray):
label = np.argmax(label)
else:
assert isinstance(label, int)
if per_class_remain[label] > 0:
valid_indexes.append(idx)
per_class_remain[label] -= 1
else:
train_indexes.append(idx)
return train_indexes, valid_indexes
class ImagenetDataProvider(DataProvider):
def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None,
n_worker=32, resize_scale=0.08, distort_color=None):
self._save_path = save_path
train_transforms = self.build_train_transform(distort_color, resize_scale)
train_dataset = datasets.ImageFolder(self.train_path, train_transforms)
if valid_size is not None:
if isinstance(valid_size, float):
valid_size = int(valid_size * len(train_dataset))
else:
assert isinstance(valid_size, int), 'invalid valid_size: %s' % valid_size
train_indexes, valid_indexes = self.random_sample_valid_set(
[cls for _, cls in train_dataset.samples], valid_size, self.n_classes,
)
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)
valid_dataset = datasets.ImageFolder(self.train_path, transforms.Compose([
transforms.Resize(self.resize_value),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
self.normalize,
]))
self.train = torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, sampler=train_sampler,
num_workers=n_worker, pin_memory=True,
)
self.valid = torch.utils.data.DataLoader(
valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
num_workers=n_worker, pin_memory=True,
)
else:
self.train = torch.utils.data.DataLoader(
train_dataset, batch_size=train_batch_size, shuffle=True,
num_workers=n_worker, pin_memory=True,
)
self.valid = None
self.test = torch.utils.data.DataLoader(
datasets.ImageFolder(self.valid_path, transforms.Compose([
transforms.Resize(self.resize_value),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
self.normalize,
])), batch_size=test_batch_size, shuffle=False, num_workers=n_worker, pin_memory=True,
)
if self.valid is None:
self.valid = self.test
@staticmethod
def name():
return 'imagenet'
@property
def data_shape(self):
return 3, self.image_size, self.image_size # C, H, W
@property
def n_classes(self):
return 1000
@property
def save_path(self):
if self._save_path is None:
self._save_path = '/dataset/imagenet'
return self._save_path
@property
def data_url(self):
raise ValueError('unable to download ImageNet')
@property
def train_path(self):
return os.path.join(self.save_path, 'train')
@property
def valid_path(self):
return os.path.join(self.save_path, 'val')
@property
def normalize(self):
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def build_train_transform(self, distort_color, resize_scale):
print('Color jitter: %s' % distort_color)
if distort_color == 'strong':
color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
elif distort_color == 'normal':
color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
else:
color_transform = None
if color_transform is None:
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
self.normalize,
])
else:
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(self.image_size, scale=(resize_scale, 1.0)),
transforms.RandomHorizontalFlip(),
color_transform,
transforms.ToTensor(),
self.normalize,
])
return train_transforms
@property
def resize_value(self):
return 256
@property
def image_size(self):
return 224
\ No newline at end of file
import os
import sys
import logging
from argparse import ArgumentParser
import torch
import datasets
from putils import get_parameters
from model import SearchMobileNet
from nni.nas.pytorch.proxylessnas import ProxylessNasTrainer
from retrain import Retrain
logger = logging.getLogger('nni_proxylessnas')
if __name__ == "__main__":
parser = ArgumentParser("proxylessnas")
# configurations of the model
parser.add_argument("--n_cell_stages", default='4,4,4,4,4,1', type=str)
parser.add_argument("--stride_stages", default='2,2,2,1,2,1', type=str)
parser.add_argument("--width_stages", default='24,40,80,96,192,320', type=str)
parser.add_argument("--bn_momentum", default=0.1, type=float)
parser.add_argument("--bn_eps", default=1e-3, type=float)
parser.add_argument("--dropout_rate", default=0, type=float)
parser.add_argument("--no_decay_keys", default='bn', type=str, choices=[None, 'bn', 'bn#bias'])
# configurations of imagenet dataset
parser.add_argument("--data_path", default='/data/imagenet/', type=str)
parser.add_argument("--train_batch_size", default=256, type=int)
parser.add_argument("--test_batch_size", default=500, type=int)
parser.add_argument("--n_worker", default=32, type=int)
parser.add_argument("--resize_scale", default=0.08, type=float)
parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
# configurations for training mode
parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain'])
# configurations for search
parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
# configurations for retrain
parser.add_argument("--exported_arch_path", default=None, type=str)
args = parser.parse_args()
if args.train_mode == 'retrain' and args.exported_arch_path is None:
logger.error('When --train_mode is retrain, --exported_arch_path must be specified.')
sys.exit(-1)
model = SearchMobileNet(width_stages=[int(i) for i in args.width_stages.split(',')],
n_cell_stages=[int(i) for i in args.n_cell_stages.split(',')],
stride_stages=[int(i) for i in args.stride_stages.split(',')],
n_classes=1000,
dropout_rate=args.dropout_rate,
bn_param=(args.bn_momentum, args.bn_eps))
logger.info('SearchMobileNet model create done')
model.init_model()
logger.info('SearchMobileNet model init done')
# move network to GPU if available
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
logger.info('Creating data provider...')
data_provider = datasets.ImagenetDataProvider(save_path=args.data_path,
train_batch_size=args.train_batch_size,
test_batch_size=args.test_batch_size,
valid_size=None,
n_worker=args.n_worker,
resize_scale=args.resize_scale,
distort_color=args.distort_color)
logger.info('Creating data provider done')
if args.no_decay_keys:
keys = args.no_decay_keys
momentum, nesterov = 0.9, True
optimizer = torch.optim.SGD([
{'params': get_parameters(model, keys, mode='exclude'), 'weight_decay': 4e-5},
{'params': get_parameters(model, keys, mode='include'), 'weight_decay': 0},
], lr=0.05, momentum=momentum, nesterov=nesterov)
else:
momentum, nesterov = 0.9, True  # these were only defined in the if-branch above
optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5)
if args.train_mode == 'search':
# this is architecture search
logger.info('Creating ProxylessNasTrainer...')
trainer = ProxylessNasTrainer(model,
model_optim=optimizer,
train_loader=data_provider.train,
valid_loader=data_provider.valid,
device=device,
warmup=True,
ckpt_path=args.checkpoint_path,
arch_path=args.arch_path)
logger.info('Start to train with ProxylessNasTrainer...')
trainer.train()
logger.info('Training done')
trainer.export(args.arch_path)
logger.info('Best architecture exported in %s', args.arch_path)
elif args.train_mode == 'retrain':
# this is retrain
from nni.nas.pytorch.fixed import apply_fixed_architecture
assert os.path.isfile(args.exported_arch_path), \
"exported_arch_path {} should be a file.".format(args.exported_arch_path)
apply_fixed_architecture(model, args.exported_arch_path, device=device)
trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300)
trainer.run()
\ No newline at end of file
import torch
import torch.nn as nn
import math
import ops
import putils
from nni.nas import pytorch as nas
class SearchMobileNet(nn.Module):
def __init__(self,
width_stages=[24,40,80,96,192,320],
n_cell_stages=[4,4,4,4,4,1],
stride_stages=[2,2,2,1,2,1],
width_mult=1, n_classes=1000,
dropout_rate=0, bn_param=(0.1, 1e-3)):
"""
Parameters
----------
width_stages: list of int
width (output channels) of each cell stage in the block
n_cell_stages: list of int
number of cells in each cell stage
stride_stages: list of int
stride of each cell stage in the block
width_mult : int
the scale factor of width
"""
super(SearchMobileNet, self).__init__()
input_channel = putils.make_divisible(32 * width_mult, 8)
first_cell_width = putils.make_divisible(16 * width_mult, 8)
for i in range(len(width_stages)):
width_stages[i] = putils.make_divisible(width_stages[i] * width_mult, 8)
# first conv
first_conv = ops.ConvLayer(3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act')
# first block
first_block_conv = ops.OPS['3x3_MBConv1'](input_channel, first_cell_width, 1)
first_block = first_block_conv
input_channel = first_cell_width
blocks = [first_block]
stage_cnt = 0
for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):
for i in range(n_cell):
if i == 0:
stride = s
else:
stride = 1
op_candidates = [ops.OPS['3x3_MBConv3'](input_channel, width, stride),
ops.OPS['3x3_MBConv6'](input_channel, width, stride),
ops.OPS['5x5_MBConv3'](input_channel, width, stride),
ops.OPS['5x5_MBConv6'](input_channel, width, stride),
ops.OPS['7x7_MBConv3'](input_channel, width, stride),
ops.OPS['7x7_MBConv6'](input_channel, width, stride)]
if stride == 1 and input_channel == width:
# if it is not the first one
op_candidates += [ops.OPS['Zero'](input_channel, width, stride)]
conv_op = nas.mutables.LayerChoice(op_candidates,
return_mask=True,
key="s{}_c{}".format(stage_cnt, i))
else:
conv_op = nas.mutables.LayerChoice(op_candidates,
return_mask=True,
key="s{}_c{}".format(stage_cnt, i))
# shortcut
if stride == 1 and input_channel == width:
# if not first cell
shortcut = ops.IdentityLayer(input_channel, input_channel)
else:
shortcut = None
inverted_residual_block = ops.MobileInvertedResidualBlock(conv_op, shortcut, op_candidates)
blocks.append(inverted_residual_block)
input_channel = width
stage_cnt += 1
# feature mix layer
last_channel = putils.make_divisible(1280 * width_mult, 8) if width_mult > 1.0 else 1280
feature_mix_layer = ops.ConvLayer(input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6', ops_order='weight_bn_act', )
classifier = ops.LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
self.first_conv = first_conv
self.blocks = nn.ModuleList(blocks)
self.feature_mix_layer = feature_mix_layer
self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = classifier
# set bn param
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])
def forward(self, x):
x = self.first_conv(x)
for block in self.blocks:
x = block(x)
x = self.feature_mix_layer(x)
x = self.global_avg_pooling(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
def set_bn_param(self, momentum, eps):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
m.momentum = momentum
m.eps = eps
return
def init_model(self, model_init='he_fout', init_div_groups=False):
for m in self.modules():
if isinstance(m, nn.Conv2d):
if model_init == 'he_fout':
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if init_div_groups:
n /= m.groups
m.weight.data.normal_(0, math.sqrt(2. / n))
elif model_init == 'he_fin':
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
if init_div_groups:
n /= m.groups
m.weight.data.normal_(0, math.sqrt(2. / n))
else:
raise NotImplementedError
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
stdv = 1. / math.sqrt(m.weight.size(1))
m.weight.data.uniform_(-stdv, stdv)
if m.bias is not None:
m.bias.data.zero_()
from collections import OrderedDict
import torch
import torch.nn as nn
from putils import get_same_padding, build_activation
OPS = {
'Identity': lambda in_C, out_C, stride: IdentityLayer(in_C, out_C, ops_order='weight_bn_act'),
'Zero': lambda in_C, out_C, stride: ZeroLayer(stride=stride),
'3x3_MBConv1': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 3, stride, 1),
'3x3_MBConv2': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 3, stride, 2),
'3x3_MBConv3': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 3, stride, 3),
'3x3_MBConv4': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 3, stride, 4),
'3x3_MBConv5': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 3, stride, 5),
'3x3_MBConv6': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 3, stride, 6),
'5x5_MBConv1': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 5, stride, 1),
'5x5_MBConv2': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 5, stride, 2),
'5x5_MBConv3': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 5, stride, 3),
'5x5_MBConv4': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 5, stride, 4),
'5x5_MBConv5': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 5, stride, 5),
'5x5_MBConv6': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 5, stride, 6),
'7x7_MBConv1': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 7, stride, 1),
'7x7_MBConv2': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 7, stride, 2),
'7x7_MBConv3': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 7, stride, 3),
'7x7_MBConv4': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 7, stride, 4),
'7x7_MBConv5': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 7, stride, 5),
'7x7_MBConv6': lambda in_C, out_C, stride: MBInvertedConvLayer(in_C, out_C, 7, stride, 6)
}
class MobileInvertedResidualBlock(nn.Module):
def __init__(self, mobile_inverted_conv, shortcut, op_candidates_list):
super(MobileInvertedResidualBlock, self).__init__()
self.mobile_inverted_conv = mobile_inverted_conv
self.shortcut = shortcut
self.op_candidates_list = op_candidates_list
def forward(self, x):
out, idx = self.mobile_inverted_conv(x)
# TODO: unify idx format
if not isinstance(idx, int):
idx = (idx == 1).nonzero()
if self.op_candidates_list[idx].is_zero_layer():
res = x
elif self.shortcut is None:
res = out
else:
conv_x = out
skip_x = self.shortcut(x)
res = skip_x + conv_x
return res
class ShuffleLayer(nn.Module):
def __init__(self, groups):
super(ShuffleLayer, self).__init__()
self.groups = groups
def forward(self, x):
batchsize, num_channels, height, width = x.size()
channels_per_group = num_channels // self.groups
# reshape
x = x.view(batchsize, self.groups, channels_per_group, height, width)
# noinspection PyUnresolvedReferences
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
class Base2DLayer(nn.Module):
def __init__(self, in_channels, out_channels,
use_bn=True, act_func='relu', dropout_rate=0, ops_order='weight_bn_act'):
super(Base2DLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.use_bn = use_bn
self.act_func = act_func
self.dropout_rate = dropout_rate
self.ops_order = ops_order
""" modules """
modules = {}
# batch norm
if self.use_bn:
if self.bn_before_weight:
modules['bn'] = nn.BatchNorm2d(in_channels)
else:
modules['bn'] = nn.BatchNorm2d(out_channels)
else:
modules['bn'] = None
# activation
modules['act'] = build_activation(self.act_func, self.ops_list[0] != 'act')
# dropout
if self.dropout_rate > 0:
modules['dropout'] = nn.Dropout2d(self.dropout_rate, inplace=True)
else:
modules['dropout'] = None
# weight
modules['weight'] = self.weight_op()
# add modules
for op in self.ops_list:
if modules[op] is None:
continue
elif op == 'weight':
if modules['dropout'] is not None:
self.add_module('dropout', modules['dropout'])
for key in modules['weight']:
self.add_module(key, modules['weight'][key])
else:
self.add_module(op, modules[op])
@property
def ops_list(self):
return self.ops_order.split('_')
@property
def bn_before_weight(self):
for op in self.ops_list:
if op == 'bn':
return True
elif op == 'weight':
return False
raise ValueError('Invalid ops_order: %s' % self.ops_order)
def weight_op(self):
raise NotImplementedError
def forward(self, x):
for module in self._modules.values():
x = module(x)
return x
@staticmethod
def is_zero_layer():
return False
class ConvLayer(Base2DLayer):
def __init__(self, in_channels, out_channels,
kernel_size=3, stride=1, dilation=1, groups=1, bias=False, has_shuffle=False,
use_bn=True, act_func='relu', dropout_rate=0, ops_order='weight_bn_act'):
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.groups = groups
self.bias = bias
self.has_shuffle = has_shuffle
super(ConvLayer, self).__init__(in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order)
def weight_op(self):
padding = get_same_padding(self.kernel_size)
if isinstance(padding, int):
padding *= self.dilation
else:
padding[0] *= self.dilation
padding[1] *= self.dilation
weight_dict = OrderedDict()
weight_dict['conv'] = nn.Conv2d(
self.in_channels, self.out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=padding,
dilation=self.dilation, groups=self.groups, bias=self.bias
)
if self.has_shuffle and self.groups > 1:
weight_dict['shuffle'] = ShuffleLayer(self.groups)
return weight_dict
class IdentityLayer(Base2DLayer):
def __init__(self, in_channels, out_channels,
use_bn=False, act_func=None, dropout_rate=0, ops_order='weight_bn_act'):
super(IdentityLayer, self).__init__(in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order)
def weight_op(self):
return None
class LinearLayer(nn.Module):
def __init__(self, in_features, out_features, bias=True,
use_bn=False, act_func=None, dropout_rate=0, ops_order='weight_bn_act'):
super(LinearLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.bias = bias
self.use_bn = use_bn
self.act_func = act_func
self.dropout_rate = dropout_rate
self.ops_order = ops_order
""" modules """
modules = {}
# batch norm
if self.use_bn:
if self.bn_before_weight:
modules['bn'] = nn.BatchNorm1d(in_features)
else:
modules['bn'] = nn.BatchNorm1d(out_features)
else:
modules['bn'] = None
# activation
modules['act'] = build_activation(self.act_func, self.ops_list[0] != 'act')
# dropout
if self.dropout_rate > 0:
modules['dropout'] = nn.Dropout(self.dropout_rate, inplace=True)
else:
modules['dropout'] = None
# linear
modules['weight'] = {'linear': nn.Linear(self.in_features, self.out_features, self.bias)}
# add modules
for op in self.ops_list:
if modules[op] is None:
continue
elif op == 'weight':
if modules['dropout'] is not None:
self.add_module('dropout', modules['dropout'])
for key in modules['weight']:
self.add_module(key, modules['weight'][key])
else:
self.add_module(op, modules[op])
@property
def ops_list(self):
return self.ops_order.split('_')
@property
def bn_before_weight(self):
for op in self.ops_list:
if op == 'bn':
return True
elif op == 'weight':
return False
raise ValueError('Invalid ops_order: %s' % self.ops_order)
def forward(self, x):
for module in self._modules.values():
x = module(x)
return x
@staticmethod
def is_zero_layer():
return False
class MBInvertedConvLayer(nn.Module):
"""
This layer is introduced in section 4.2 in the paper https://arxiv.org/pdf/1812.00332.pdf
"""
def __init__(self, in_channels, out_channels,
kernel_size=3, stride=1, expand_ratio=6, mid_channels=None):
super(MBInvertedConvLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.expand_ratio = expand_ratio
self.mid_channels = mid_channels
if self.mid_channels is None:
feature_dim = round(self.in_channels * self.expand_ratio)
else:
feature_dim = self.mid_channels
if self.expand_ratio == 1:
self.inverted_bottleneck = None
else:
self.inverted_bottleneck = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
('bn', nn.BatchNorm2d(feature_dim)),
('act', nn.ReLU6(inplace=True)),
]))
pad = get_same_padding(self.kernel_size)
self.depth_conv = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(feature_dim, feature_dim, kernel_size, stride, pad, groups=feature_dim, bias=False)),
('bn', nn.BatchNorm2d(feature_dim)),
('act', nn.ReLU6(inplace=True)),
]))
self.point_linear = nn.Sequential(OrderedDict([
('conv', nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
('bn', nn.BatchNorm2d(out_channels)),
]))
def forward(self, x):
if self.inverted_bottleneck:
x = self.inverted_bottleneck(x)
x = self.depth_conv(x)
x = self.point_linear(x)
return x
@staticmethod
def is_zero_layer():
return False
class ZeroLayer(nn.Module):
def __init__(self, stride):
super(ZeroLayer, self).__init__()
self.stride = stride
def forward(self, x):
'''n, c, h, w = x.size()
h //= self.stride
w //= self.stride
device = x.get_device() if x.is_cuda else torch.device('cpu')
# noinspection PyUnresolvedReferences
padding = torch.zeros(n, c, h, w, device=device, requires_grad=False)
return padding'''
return x * 0
@staticmethod
def is_zero_layer():
return True
import torch.nn as nn
def get_parameters(model, keys=None, mode='include'):
if keys is None:
for name, param in model.named_parameters():
yield param
elif mode == 'include':
for name, param in model.named_parameters():
flag = False
for key in keys:
if key in name:
flag = True
break
if flag:
yield param
elif mode == 'exclude':
for name, param in model.named_parameters():
flag = True
for key in keys:
if key in name:
flag = False
break
if flag:
yield param
else:
raise ValueError('do not support: %s' % mode)
def get_same_padding(kernel_size):
if isinstance(kernel_size, tuple):
assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size
p1 = get_same_padding(kernel_size[0])
p2 = get_same_padding(kernel_size[1])
return p1, p2
assert isinstance(kernel_size, int), 'kernel size should be either `int` or `tuple`'
assert kernel_size % 2 > 0, 'kernel size should be odd number'
return kernel_size // 2
def build_activation(act_func, inplace=True):
if act_func == 'relu':
return nn.ReLU(inplace=inplace)
elif act_func == 'relu6':
return nn.ReLU6(inplace=inplace)
elif act_func == 'tanh':
return nn.Tanh()
elif act_func == 'sigmoid':
return nn.Sigmoid()
elif act_func is None:
return None
else:
raise ValueError('do not support: %s' % act_func)
def make_divisible(v, divisor, min_val=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_val is None:
min_val = divisor
new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
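# Worked examples (illustrative, not part of the original file):
#   make_divisible(30, 8) -> 32  (rounded to the nearest multiple of 8)
#   make_divisible(16, 8) -> 16  (already divisible)
#   make_divisible(7, 8)  -> 8   (clamped up to min_val = divisor)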
import time
import math
from datetime import timedelta
import torch
from torch import nn as nn
from nni.nas.pytorch.utils import AverageMeter
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
logsoftmax = nn.LogSoftmax(dim=1)
n_classes = pred.size(1)
# convert to one-hot
target = torch.unsqueeze(target, 1)
soft_target = torch.zeros_like(pred)
soft_target.scatter_(1, target, 1)
# label smoothing
soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
def accuracy(output, target, topk=(1,)):
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
class Retrain:
def __init__(self, model, optimizer, device, data_provider, n_epochs):
self.model = model
self.optimizer = optimizer
self.device = device
self.train_loader = data_provider.train
self.valid_loader = data_provider.valid
self.test_loader = data_provider.test
self.n_epochs = n_epochs
self.criterion = nn.CrossEntropyLoss()
def run(self):
self.model = torch.nn.DataParallel(self.model)
self.model.to(self.device)
# train
self.train()
# validate
self.validate(is_test=False)
# test
self.validate(is_test=True)
def train_one_epoch(self, adjust_lr_func, train_log_func, label_smoothing=0.1):
batch_time = AverageMeter('batch_time')
data_time = AverageMeter('data_time')
losses = AverageMeter('losses')
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
self.model.train()
end = time.time()
for i, (images, labels) in enumerate(self.train_loader):
data_time.update(time.time() - end)
new_lr = adjust_lr_func(i)
images, labels = images.to(self.device), labels.to(self.device)
output = self.model(images)
if label_smoothing > 0:
loss = cross_entropy_with_label_smoothing(output, labels, label_smoothing)
else:
loss = self.criterion(output, labels)
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# compute gradient and do SGD step
self.model.zero_grad() # or self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % 10 == 0 or i + 1 == len(self.train_loader):
batch_log = train_log_func(i, batch_time, data_time, losses, top1, top5, new_lr)
print(batch_log)
return top1, top5
def train(self, validation_frequency=1):
best_acc = 0
nBatch = len(self.train_loader)
def train_log_func(epoch_, i, batch_time, data_time, losses, top1, top5, lr):
batch_log = 'Train [{0}][{1}/{2}]\t' \
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'. \
format(epoch_ + 1, i, nBatch - 1,
batch_time=batch_time, data_time=data_time, losses=losses, top1=top1)
batch_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(top5=top5)
batch_log += '\tlr {lr:.5f}'.format(lr=lr)
return batch_log
def adjust_learning_rate(n_epochs, optimizer, epoch, batch=0, nBatch=None):
""" adjust learning of a given optimizer and return the new learning rate """
# cosine
T_total = n_epochs * nBatch
T_cur = epoch * nBatch + batch
# init_lr = 0.05
new_lr = 0.5 * 0.05 * (1 + math.cos(math.pi * T_cur / T_total))
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
return new_lr
for epoch in range(self.n_epochs):
print('\n', '-' * 30, 'Train epoch: %d' % (epoch + 1), '-' * 30, '\n')
end = time.time()
train_top1, train_top5 = self.train_one_epoch(
lambda i: adjust_learning_rate(self.n_epochs, self.optimizer, epoch, i, nBatch),
lambda i, batch_time, data_time, losses, top1, top5, new_lr:
train_log_func(epoch, i, batch_time, data_time, losses, top1, top5, new_lr),
)
time_per_epoch = time.time() - end
seconds_left = int((self.n_epochs - epoch - 1) * time_per_epoch)
print('Time per epoch: %s, Est. complete in: %s' % (
str(timedelta(seconds=time_per_epoch)),
str(timedelta(seconds=seconds_left))))
if (epoch + 1) % validation_frequency == 0:
val_loss, val_acc, val_acc5 = self.validate(is_test=False)
is_best = val_acc > best_acc
best_acc = max(best_acc, val_acc)
val_log = 'Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f} ({4:.3f})'.\
format(epoch + 1, self.n_epochs, val_loss, val_acc, best_acc)
val_log += '\ttop-5 acc {0:.3f}\tTrain top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'.\
format(val_acc5, top1=train_top1, top5=train_top5)
print(val_log)
else:
is_best = False
def validate(self, is_test=True):
if is_test:
data_loader = self.test_loader
else:
data_loader = self.valid_loader
self.model.eval()
batch_time = AverageMeter('batch_time')
losses = AverageMeter('losses')
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
end = time.time()
with torch.no_grad():
for i, (images, labels) in enumerate(data_loader):
images, labels = images.to(self.device), labels.to(self.device)
# compute output
output = self.model(images)
loss = self.criterion(output, labels)
# measure accuracy and record loss
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % 10 == 0 or i + 1 == len(data_loader):
if is_test:
prefix = 'Test'
else:
prefix = 'Valid'
test_log = prefix + ': [{0}/{1}]\t'\
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'\
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'\
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'.\
format(i, len(data_loader) - 1, batch_time=batch_time, loss=losses, top1=top1)
test_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(top5=top5)
print(test_log)
return losses.avg, top1.avg, top5.avg
\ No newline at end of file
...@@ -64,6 +64,10 @@ class BaseMutator(nn.Module):
"""
return self._structured_mutables
@property
def undedup_mutables(self):
return self._structured_mutables.traverse(deduplicate=False)
def forward(self, *inputs):
"""
Warnings
...
from .mutator import ProxylessNasMutator
from .trainer import ProxylessNasTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np
from nni.nas.pytorch.base_mutator import BaseMutator
from nni.nas.pytorch.mutables import LayerChoice
from .utils import detach_variable
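# ArchGradientFunction implements the memory-saving trick behind the 'full_v2'
# mode: forward runs only the active op on a detached input, and backward
# recomputes the inactive ops (via backward_func) to obtain gradients for the
# binary gates without storing the activations of every candidate.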
class ArchGradientFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, binary_gates, run_func, backward_func):
ctx.run_func = run_func
ctx.backward_func = backward_func
detached_x = detach_variable(x)
with torch.enable_grad():
output = run_func(detached_x)
ctx.save_for_backward(detached_x, output)
return output.data
@staticmethod
def backward(ctx, grad_output):
detached_x, output = ctx.saved_tensors
grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
# compute gradients w.r.t. binary_gates
binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
return grad_x[0], binary_grads, None, None
class MixedOp(nn.Module):
"""
This class is to instantiate and manage info of one LayerChoice.
It includes architecture weights, binary weights, and member functions
operating the weights.
forward_mode:
forward/backward mode for LayerChoice: None, two, full, and full_v2.
For training architecture weights, we use full_v2 by default, and for training
model weights, we use None.
"""
forward_mode = None
def __init__(self, mutable):
"""
Parameters
----------
mutable : LayerChoice
A LayerChoice in user model
"""
super(MixedOp, self).__init__()
self.ap_path_alpha = nn.Parameter(torch.Tensor(mutable.length))
self.ap_path_wb = nn.Parameter(torch.Tensor(mutable.length))
self.ap_path_alpha.requires_grad = False
self.ap_path_wb.requires_grad = False
self.active_index = [0]
self.inactive_index = None
self.log_prob = None
self.current_prob_over_ops = None
self.n_choices = mutable.length
def get_ap_path_alpha(self):
return self.ap_path_alpha
def to_requires_grad(self):
self.ap_path_alpha.requires_grad = True
self.ap_path_wb.requires_grad = True
def to_disable_grad(self):
self.ap_path_alpha.requires_grad = False
self.ap_path_wb.requires_grad = False
def forward(self, mutable, x):
"""
Define forward of LayerChoice. For 'full_v2', backward is also defined.
The 'two' mode is explained in section 3.2.1 in the paper.
The 'full_v2' mode is explained in Appendix D in the paper.
Parameters
----------
mutable : LayerChoice
this layer's mutable
x : tensor
inputs of this layer, only support one input
Returns
-------
output: tensor
output of this layer
"""
if MixedOp.forward_mode == 'full' or MixedOp.forward_mode == 'two':
output = 0
for _i in self.active_index:
oi = self.candidate_ops[_i](x)
output = output + self.ap_path_wb[_i] * oi
for _i in self.inactive_index:
oi = self.candidate_ops[_i](x)
output = output + self.ap_path_wb[_i] * oi.detach()
elif MixedOp.forward_mode == 'full_v2':
def run_function(key, candidate_ops, active_id):
def forward(_x):
return candidate_ops[active_id](_x)
return forward
def backward_function(key, candidate_ops, active_id, binary_gates):
def backward(_x, _output, grad_output):
binary_grads = torch.zeros_like(binary_gates.data)
with torch.no_grad():
for k in range(len(candidate_ops)):
if k != active_id:
out_k = candidate_ops[k](_x.data)
else:
out_k = _output.data
grad_k = torch.sum(out_k * grad_output)
binary_grads[k] = grad_k
return binary_grads
return backward
output = ArchGradientFunction.apply(
x, self.ap_path_wb, run_function(mutable.key, mutable.choices, self.active_index[0]),
backward_function(mutable.key, mutable.choices, self.active_index[0], self.ap_path_wb))
else:
output = self.active_op(mutable)(x)
return output
@property
def probs_over_ops(self):
"""
Apply softmax on alpha to generate probability distribution
Returns
-------
pytorch tensor
probability distribution
"""
probs = F.softmax(self.ap_path_alpha, dim=0) # softmax to probability
return probs
@property
def chosen_index(self):
"""
choose the op with max prob
Returns
-------
int
index of the chosen one
numpy.float32
prob of the chosen one
"""
probs = self.probs_over_ops.data.cpu().numpy()
index = int(np.argmax(probs))
return index, probs[index]
def active_op(self, mutable):
"""
assume only one path is active
Returns
-------
PyTorch module
the chosen operation
"""
return mutable.choices[self.active_index[0]]
@property
def active_op_index(self):
"""
return active op's index, the active op is sampled
Returns
-------
int
index of the active op
"""
return self.active_index[0]
def set_chosen_op_active(self):
"""
set chosen index, active and inactive indexes
"""
chosen_idx, _ = self.chosen_index
self.active_index = [chosen_idx]
self.inactive_index = [_i for _i in range(0, chosen_idx)] + \
[_i for _i in range(chosen_idx + 1, self.n_choices)]
def binarize(self, mutable):
"""
Sample based on alpha, and set binary weights accordingly.
ap_path_wb is set in this function, which is called binarize.
Parameters
----------
mutable : LayerChoice
this layer's mutable
"""
self.log_prob = None
# reset binary gates
self.ap_path_wb.data.zero_()
probs = self.probs_over_ops
if MixedOp.forward_mode == 'two':
# sample two ops according to probs
sample_op = torch.multinomial(probs.data, 2, replacement=False)
probs_slice = F.softmax(torch.stack([
self.ap_path_alpha[idx] for idx in sample_op
]), dim=0)
self.current_prob_over_ops = torch.zeros_like(probs)
for i, idx in enumerate(sample_op):
self.current_prob_over_ops[idx] = probs_slice[i]
# choose one to be active and the other to be inactive according to probs_slice
c = torch.multinomial(probs_slice.data, 1)[0] # 0 or 1
active_op = sample_op[c].item()
inactive_op = sample_op[1-c].item()
self.active_index = [active_op]
self.inactive_index = [inactive_op]
# set binary gate
self.ap_path_wb.data[active_op] = 1.0
else:
sample = torch.multinomial(probs, 1)[0].item()
self.active_index = [sample]
self.inactive_index = [_i for _i in range(0, sample)] + \
[_i for _i in range(sample + 1, len(mutable.choices))]
self.log_prob = torch.log(probs[sample])
self.current_prob_over_ops = probs
self.ap_path_wb.data[sample] = 1.0
# avoid over-regularization
for choice in mutable.choices:
for _, param in choice.named_parameters():
param.grad = None
@staticmethod
def delta_ij(i, j):
if i == j:
return 1
else:
return 0
def set_arch_param_grad(self, mutable):
"""
Calculate alpha gradient for this LayerChoice.
It is calculated using gradient of binary gate, probs of ops.
"""
binary_grads = self.ap_path_wb.grad.data
if self.active_op(mutable).is_zero_layer():
self.ap_path_alpha.grad = None
return
if self.ap_path_alpha.grad is None:
self.ap_path_alpha.grad = torch.zeros_like(self.ap_path_alpha.data)
if MixedOp.forward_mode == 'two':
involved_idx = self.active_index + self.inactive_index
probs_slice = F.softmax(torch.stack([
self.ap_path_alpha[idx] for idx in involved_idx
]), dim=0).data
for i in range(2):
for j in range(2):
origin_i = involved_idx[i]
origin_j = involved_idx[j]
self.ap_path_alpha.grad.data[origin_i] += \
binary_grads[origin_j] * probs_slice[j] * (MixedOp.delta_ij(i, j) - probs_slice[i])
for _i, idx in enumerate(self.active_index):
self.active_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
for _i, idx in enumerate(self.inactive_index):
self.inactive_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
else:
probs = self.probs_over_ops.data
for i in range(self.n_choices):
for j in range(self.n_choices):
self.ap_path_alpha.grad.data[i] += binary_grads[j] * probs[j] * (MixedOp.delta_ij(i, j) - probs[i])
return
def rescale_updated_arch_param(self):
"""
rescale architecture weights for the 'two' mode.
"""
if not isinstance(self.active_index[0], tuple):
assert self.active_op.is_zero_layer()
return
involved_idx = [idx for idx, _ in (self.active_index + self.inactive_index)]
old_alphas = [alpha for _, alpha in (self.active_index + self.inactive_index)]
new_alphas = [self.ap_path_alpha.data[idx] for idx in involved_idx]
offset = math.log(
sum([math.exp(alpha) for alpha in new_alphas]) / sum([math.exp(alpha) for alpha in old_alphas])
)
for idx in involved_idx:
self.ap_path_alpha.data[idx] -= offset
class ProxylessNasMutator(BaseMutator):
"""
This mutator initializes and operates all the LayerChoices of the input model.
It is for the corresponding trainer to control the training process of LayerChoices,
coordinating with whole training process.
"""
def __init__(self, model):
"""
Init a MixedOp instance for each mutable i.e., LayerChoice.
And register the instantiated MixedOp in corresponding LayerChoice.
If it is not registered in LayerChoice, DataParallel does not work,
because the architecture weights would not be included in the DataParallel model.
When MixedOps are registered, we use ```requires_grad``` to control
whether to calculate gradients of architecture weights.
Parameters
----------
model : pytorch model
The model that users want to tune, it includes search space defined with nni nas apis
"""
super(ProxylessNasMutator, self).__init__(model)
self._unused_modules = None
self.mutable_list = []
for mutable in self.undedup_mutables:
self.mutable_list.append(mutable)
mutable.registered_module = MixedOp(mutable)
def on_forward_layer_choice(self, mutable, *inputs):
"""
Callback of layer choice forward. This function defines the forward
logic of the input mutable. The mutable is only an interface; its real
implementation is defined in the mutator.
Parameters
----------
mutable: LayerChoice
forward logic of this input mutable
inputs: list of torch.Tensor
inputs of this mutable
Returns
-------
torch.Tensor
output of this mutable, i.e., LayerChoice
int
index of the chosen op
"""
# FIXME: return mask, to be consistent with other algorithms
idx = mutable.registered_module.active_op_index
return mutable.registered_module(mutable, *inputs), idx
def reset_binary_gates(self):
"""
For each LayerChoice, binarize binary weights
based on alpha to only activate one op.
It traverses all the mutables in the model to do this.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.binarize(mutable)
def set_chosen_op_active(self):
"""
For each LayerChoice, set the op with highest alpha as the chosen op.
Usually used for validation.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.set_chosen_op_active()
def num_arch_params(self):
"""
The number of mutables, i.e., LayerChoice
Returns
-------
int
the number of LayerChoice in user model
"""
return len(self.mutable_list)
def set_arch_param_grad(self):
"""
For each LayerChoice, calculate gradients for architecture weights, i.e., alpha
"""
for mutable in self.undedup_mutables:
mutable.registered_module.set_arch_param_grad(mutable)
def get_architecture_parameters(self):
"""
Get all the architecture parameters.
yield
-----
PyTorch Parameter
Return ap_path_alpha of the traversed mutable
"""
for mutable in self.undedup_mutables:
yield mutable.registered_module.get_ap_path_alpha()
def change_forward_mode(self, mode):
"""
Update forward mode of MixedOps, as training architecture weights and
model weights use different forward modes.
"""
MixedOp.forward_mode = mode
def get_forward_mode(self):
"""
Get forward mode of MixedOp
Returns
-------
string
the current forward mode of MixedOp
"""
return MixedOp.forward_mode
def rescale_updated_arch_param(self):
"""
Rescale architecture weights in 'two' mode.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.rescale_updated_arch_param()
def unused_modules_off(self):
"""
Remove unused modules for each mutable.
The removed modules are kept in ```self._unused_modules``` so they can be put back later.
"""
self._unused_modules = []
for mutable in self.undedup_mutables:
mixed_op = mutable.registered_module
unused = {}
if self.get_forward_mode() in ['full', 'two', 'full_v2']:
involved_index = mixed_op.active_index + mixed_op.inactive_index
else:
involved_index = mixed_op.active_index
for i in range(mixed_op.n_choices):
if i not in involved_index:
unused[i] = mutable.choices[i]
mutable.choices[i] = None
self._unused_modules.append(unused)
def unused_modules_back(self):
"""
Put the removed modules back.
"""
if self._unused_modules is None:
return
for m, unused in zip(self.mutable_list, self._unused_modules):
for i in unused:
m.choices[i] = unused[i]
self._unused_modules = None
def arch_requires_grad(self):
"""
Make architecture weights require gradient
"""
for mutable in self.undedup_mutables:
mutable.registered_module.to_requires_grad()
def arch_disable_grad(self):
"""
Disable gradient of architecture weights, i.e., do not
calculate gradients for them.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.to_disable_grad()
def sample_final(self):
"""
Generate the final chosen architecture.
Returns
-------
dict
the choice of each mutable, i.e., LayerChoice
"""
result = dict()
for mutable in self.undedup_mutables:
assert isinstance(mutable, LayerChoice)
index, _ = mutable.registered_module.chosen_index
# pylint: disable=not-callable
result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=mutable.length).view(-1).bool()
return result
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import time
import json
import logging
import torch
from torch import nn as nn
from nni.nas.pytorch.base_trainer import BaseTrainer
from nni.nas.pytorch.trainer import TorchTensorEncoder
from nni.nas.pytorch.utils import AverageMeter
from .mutator import ProxylessNasMutator
from .utils import cross_entropy_with_label_smoothing, accuracy
logger = logging.getLogger(__name__)
class ProxylessNasTrainer(BaseTrainer):
def __init__(self, model, model_optim, device,
train_loader, valid_loader, label_smoothing=0.1,
n_epochs=120, init_lr=0.025, binary_mode='full_v2',
arch_init_type='normal', arch_init_ratio=1e-3,
arch_optim_lr=1e-3, arch_weight_decay=0,
grad_update_arch_param_every=5, grad_update_steps=1,
warmup=True, warmup_epochs=25,
arch_valid_frequency=1,
load_ckpt=False, ckpt_path=None, arch_path=None):
"""
Parameters
----------
model : pytorch model
the user model, which has mutables
model_optim : pytorch optimizer
the user defined optimizer
device : pytorch device
the devices to train/search the model
train_loader : pytorch data loader
data loader for the training set
valid_loader : pytorch data loader
data loader for the validation set
label_smoothing : float
the degree of label smoothing
n_epochs : int
number of epochs to train/search
init_lr : float
init learning rate for training the model
binary_mode : str
the forward/backward mode for the binary weights in mutator
arch_init_type : str
the way to init architecture parameters
arch_init_ratio : float
the ratio to init architecture parameters
arch_optim_lr : float
learning rate of the architecture parameters optimizer
arch_weight_decay : float
weight decay of the architecture parameters optimizer
grad_update_arch_param_every : int
update architecture weights every this number of minibatches
grad_update_steps : int
during each update of architecture weights, the number of steps to train
warmup : bool
whether to do warmup
warmup_epochs : int
the number of epochs to do during warmup
arch_valid_frequency : int
frequency of printing validation result
load_ckpt : bool
whether to load a checkpoint
ckpt_path : str
checkpoint path, if load_ckpt is True, ckpt_path cannot be None
arch_path : str
the path to store chosen architecture
"""
self.model = model
self.model_optim = model_optim
self.train_loader = train_loader
self.valid_loader = valid_loader
self.device = device
self.n_epochs = n_epochs
self.init_lr = init_lr
self.warmup = warmup
self.warmup_epochs = warmup_epochs
self.arch_valid_frequency = arch_valid_frequency
self.label_smoothing = label_smoothing
self.train_batch_size = train_loader.batch_sampler.batch_size
self.valid_batch_size = valid_loader.batch_sampler.batch_size
# update architecture parameters every this number of minibatches
self.grad_update_arch_param_every = grad_update_arch_param_every
# the number of steps per architecture parameter update
self.grad_update_steps = grad_update_steps
self.binary_mode = binary_mode
self.load_ckpt = load_ckpt
self.ckpt_path = ckpt_path
self.arch_path = arch_path
# init mutator
self.mutator = ProxylessNasMutator(model)
# DataParallel should be applied after the mutator is initialized
self.model = torch.nn.DataParallel(self.model)
self.model.to(self.device)
# iter of valid dataset for training architecture weights
self._valid_iter = None
# init architecture weights
self._init_arch_params(arch_init_type, arch_init_ratio)
# build architecture optimizer
self.arch_optimizer = torch.optim.Adam(self.mutator.get_architecture_parameters(),
arch_optim_lr,
weight_decay=arch_weight_decay,
betas=(0, 0.999),
eps=1e-8)
self.criterion = nn.CrossEntropyLoss()
self.warmup_curr_epoch = 0
self.train_curr_epoch = 0
def _init_arch_params(self, init_type='normal', init_ratio=1e-3):
"""
Initialize architecture weights
"""
for param in self.mutator.get_architecture_parameters():
if init_type == 'normal':
param.data.normal_(0, init_ratio)
elif init_type == 'uniform':
param.data.uniform_(-init_ratio, init_ratio)
else:
raise NotImplementedError
def _validate(self):
"""
Do validation. During validation, LayerChoices use the chosen active op.
Returns
-------
float, float, float
average loss, average top1 accuracy, average top5 accuracy
"""
self.valid_loader.batch_sampler.batch_size = self.valid_batch_size
self.valid_loader.batch_sampler.drop_last = False
self.mutator.set_chosen_op_active()
# remove unused modules to save memory
self.mutator.unused_modules_off()
# test on validation set under train mode
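# model.train() keeps BatchNorm in batch-statistics mode (and dropout active) for the chosen
# sub-network; torch.no_grad() below still disables gradient computation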
self.model.train()
batch_time = AverageMeter('batch_time')
losses = AverageMeter('losses')
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
end = time.time()
with torch.no_grad():
for i, (images, labels) in enumerate(self.valid_loader):
images, labels = images.to(self.device), labels.to(self.device)
output = self.model(images)
loss = self.criterion(output, labels)
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % 10 == 0 or i + 1 == len(self.valid_loader):
test_log = 'Valid' + ': [{0}/{1}]\t'\
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'\
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'\
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'.\
format(i, len(self.valid_loader) - 1, batch_time=batch_time, loss=losses, top1=top1)
# also report top-5 accuracy
test_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(top5=top5)
logger.info(test_log)
self.mutator.unused_modules_back()
return losses.avg, top1.avg, top5.avg
def _warm_up(self):
"""
Warm up the model. During warmup, architecture weights are not trained.
"""
lr_max = 0.05
data_loader = self.train_loader
nBatch = len(data_loader)
T_total = self.warmup_epochs * nBatch # total num of batches
for epoch in range(self.warmup_curr_epoch, self.warmup_epochs):
logger.info('\n--------Warmup epoch: %d--------\n', epoch + 1)
batch_time = AverageMeter('batch_time')
data_time = AverageMeter('data_time')
losses = AverageMeter('losses')
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
# switch to train mode
self.model.train()
end = time.time()
logger.info('warm_up epoch: %d', epoch)
for i, (images, labels) in enumerate(data_loader):
data_time.update(time.time() - end)
# cosine warmup schedule: lr = 0.5 * lr_max * (1 + cos(pi * T_cur / T_total))
T_cur = epoch * nBatch + i
warmup_lr = 0.5 * lr_max * (1 + math.cos(math.pi * T_cur / T_total))
for param_group in self.model_optim.param_groups:
param_group['lr'] = warmup_lr
images, labels = images.to(self.device), labels.to(self.device)
# compute output
self.mutator.reset_binary_gates() # random sample binary gates
self.mutator.unused_modules_off() # remove unused module for speedup
output = self.model(images)
if self.label_smoothing > 0:
loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
else:
loss = self.criterion(output, labels)
# measure accuracy and record loss
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# compute gradient and do SGD step
self.model.zero_grad()
loss.backward()
self.model_optim.step()
# unused modules back
self.mutator.unused_modules_back()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % 10 == 0 or i + 1 == nBatch:
batch_log = 'Warmup Train [{0}][{1}/{2}]\t' \
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
losses=losses, top1=top1, top5=top5, lr=warmup_lr)
logger.info(batch_log)
val_loss, val_top1, val_top5 = self._validate()
val_log = 'Warmup Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}\t' \
'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
format(epoch + 1, self.warmup_epochs, val_loss, val_top1, val_top5, top1=top1, top5=top5)
logger.info(val_log)
self.save_checkpoint()
self.warmup_curr_epoch += 1
def _get_update_schedule(self, nBatch):
"""
Generate the schedule for training architecture weights. Each key is the index of the
minibatch after which architecture weights are updated; the value is the number of update steps.
Parameters
----------
nBatch : int
the total number of minibatches in one epoch
Returns
-------
dict
the schedule for updating architecture weights
"""
schedule = {}
for i in range(nBatch):
if (i + 1) % self.grad_update_arch_param_every == 0:
schedule[i] = self.grad_update_steps
return schedule
def _calc_learning_rate(self, epoch, batch=0, nBatch=None):
"""
Compute the cosine-annealed learning rate for the given epoch and minibatch.
"""
T_total = self.n_epochs * nBatch
T_cur = epoch * nBatch + batch
lr = 0.5 * self.init_lr * (1 + math.cos(math.pi * T_cur / T_total))
return lr
def _adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
"""
Adjust the learning rate of the given optimizer and return the new learning rate
Parameters
----------
optimizer : pytorch optimizer
the used optimizer
epoch : int
the current epoch number
batch : int
the current minibatch
nBatch : int
the total number of minibatches in one epoch
Returns
-------
float
the adjusted learning rate
"""
new_lr = self._calc_learning_rate(epoch, batch, nBatch)
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
return new_lr
def _train(self):
"""
Train the model; both model weights and architecture weights are trained.
Architecture weights are trained according to the schedule.
Before updating architecture weights, ``requires_grad`` is enabled.
It is disabled again after the update, so that architecture weights are
not modified while training model weights.
"""
nBatch = len(self.train_loader)
arch_param_num = self.mutator.num_arch_params()
binary_gates_num = self.mutator.num_arch_params()
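# both counts are taken from num_arch_params(); in this implementation every architecture
# parameter has a corresponding binary gate, so the two numbers are always equal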
logger.info('#arch_params: %d\t#binary_gates: %d', arch_param_num, binary_gates_num)
update_schedule = self._get_update_schedule(nBatch)
for epoch in range(self.train_curr_epoch, self.n_epochs):
logger.info('\n--------Train epoch: %d--------\n', epoch + 1)
batch_time = AverageMeter('batch_time')
data_time = AverageMeter('data_time')
losses = AverageMeter('losses')
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
# switch to train mode
self.model.train()
end = time.time()
for i, (images, labels) in enumerate(self.train_loader):
data_time.update(time.time() - end)
lr = self._adjust_learning_rate(self.model_optim, epoch, batch=i, nBatch=nBatch)
# train weight parameters
images, labels = images.to(self.device), labels.to(self.device)
self.mutator.reset_binary_gates()
self.mutator.unused_modules_off()
output = self.model(images)
if self.label_smoothing > 0:
loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
else:
loss = self.criterion(output, labels)
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
self.model.zero_grad()
loss.backward()
self.model_optim.step()
self.mutator.unused_modules_back()
if epoch > 0:
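# architecture weights are only updated from the second epoch on; within an epoch, every
# grad_update_arch_param_every minibatches the schedule triggers grad_update_steps
# architecture update steps on validation data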
for _ in range(update_schedule.get(i, 0)):
start_time = time.time()
# GradientArchSearchConfig
self.mutator.arch_requires_grad()
arch_loss, exp_value = self._gradient_step()
self.mutator.arch_disable_grad()
used_time = time.time() - start_time
log_str = 'Architecture [%d-%d]\t Time %.4f\t Loss %.4f\t Exp. value %s' % \
(epoch + 1, i, used_time, arch_loss, exp_value)
logger.info(log_str)
batch_time.update(time.time() - end)
end = time.time()
# training log
if i % 10 == 0 or i + 1 == nBatch:
batch_log = 'Train [{0}][{1}/{2}]\t' \
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t' \
'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
losses=losses, top1=top1, top5=top5, lr=lr)
logger.info(batch_log)
# validate
if (epoch + 1) % self.arch_valid_frequency == 0:
val_loss, val_top1, val_top5 = self._validate()
val_log = 'Valid [{0}]\tloss {1:.3f}\ttop-1 acc {2:.3f} \ttop-5 acc {3:.3f}\t' \
'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
format(epoch + 1, val_loss, val_top1, val_top5, top1=top1, top5=top5)
logger.info(val_log)
self.save_checkpoint()
self.train_curr_epoch += 1
def _valid_next_batch(self):
"""
Get the next minibatch from the validation set; the iterator is recreated when exhausted.
Returns
-------
(tensor, tensor)
the tuple of images and labels
"""
if self._valid_iter is None:
self._valid_iter = iter(self.valid_loader)
try:
data = next(self._valid_iter)
except StopIteration:
self._valid_iter = iter(self.valid_loader)
data = next(self._valid_iter)
return data
def _gradient_step(self):
"""
Perform one gradient step on the architecture weights. The mutator is used
extensively here to operate on the architecture weights.
Returns
-------
float, None
the architecture loss, and the expected value (always None in this implementation)
"""
# use the same batch size as train batch size for architecture weights
self.valid_loader.batch_sampler.batch_size = self.train_batch_size
self.valid_loader.batch_sampler.drop_last = True
self.model.train()
self.mutator.change_forward_mode(self.binary_mode)
time1 = time.time() # time
# sample a batch of data from validation set
images, labels = self._valid_next_batch()
images, labels = images.to(self.device), labels.to(self.device)
time2 = time.time() # time
self.mutator.reset_binary_gates()
self.mutator.unused_modules_off()
output = self.model(images)
time3 = time.time()
ce_loss = self.criterion(output, labels)
expected_value = None
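# expected_value is a placeholder and is always None here; the ProxylessNAS paper additionally
# optimizes an expected latency term, which is not implemented in this trainer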
loss = ce_loss
self.model.zero_grad()
loss.backward()
self.mutator.set_arch_param_grad()
self.arch_optimizer.step()
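# in 'two' mode only the two sampled paths were updated, so the architecture parameters are
# rescaled afterwards to keep the probabilities of the unsampled paths unchanged
# (following the ProxylessNAS paper)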
if self.mutator.get_forward_mode() == 'two':
self.mutator.rescale_updated_arch_param()
self.mutator.unused_modules_back()
self.mutator.change_forward_mode(None)
time4 = time.time()
logger.info('(%.4f, %.4f, %.4f)', time2 - time1, time3 - time2, time4 - time3)
return loss.data.item(), expected_value.item() if expected_value is not None else None
def save_checkpoint(self):
"""
Save a checkpoint of the whole model. Model weights and architecture weights are saved
to ``ckpt_path``, and the currently chosen architecture is exported to ``arch_path``.
"""
if self.ckpt_path:
state = {
'warmup_curr_epoch': self.warmup_curr_epoch,
'train_curr_epoch': self.train_curr_epoch,
'model': self.model.state_dict(),
'optim': self.model_optim.state_dict(),
'arch_optim': self.arch_optimizer.state_dict()
}
torch.save(state, self.ckpt_path)
if self.arch_path:
self.export(self.arch_path)
def load_checkpoint(self):
"""
Load the checkpoint from ``ckpt_path``.
"""
assert self.ckpt_path is not None, "If load_ckpt is True, ckpt_path should not be None"
ckpt = torch.load(self.ckpt_path)
self.warmup_curr_epoch = ckpt['warmup_curr_epoch']
self.train_curr_epoch = ckpt['train_curr_epoch']
self.model.load_state_dict(ckpt['model'])
self.model_optim.load_state_dict(ckpt['optim'])
self.arch_optimizer.load_state_dict(ckpt['arch_optim'])
def train(self):
"""
Train the whole model.
"""
if self.load_ckpt:
self.load_checkpoint()
if self.warmup:
self._warm_up()
self._train()
def export(self, file_name):
"""
Export the chosen architecture into a file
Parameters
----------
file_name : str
the file that stores exported chosen architecture
"""
exported_arch = self.mutator.sample_final()
with open(file_name, 'w') as f:
json.dump(exported_arch, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
def validate(self):
raise NotImplementedError
def checkpoint(self):
raise NotImplementedError
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
def detach_variable(inputs):
"""
Detach a tensor, or recursively a tuple of tensors, from the computation graph
while preserving its ``requires_grad`` flag.
Parameters
----------
inputs : pytorch tensor or tuple of tensors
the tensor(s) to detach
"""
if isinstance(inputs, tuple):
return tuple([detach_variable(x) for x in inputs])
else:
x = inputs.detach()
x.requires_grad = inputs.requires_grad
return x
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
"""
Parameters
----------
pred : pytorch tensor
predicted value
target : pytorch tensor
label
label_smoothing : float
the degree of label smoothing
Returns
-------
pytorch tensor
cross entropy
"""
logsoftmax = nn.LogSoftmax(dim=1)  # log-softmax over the class dimension
n_classes = pred.size(1)
# convert to one-hot
target = torch.unsqueeze(target, 1)
soft_target = torch.zeros_like(pred)
soft_target.scatter_(1, target, 1)
# label smoothing
soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
def accuracy(output, target, topk=(1,)):
"""
Compute the accuracy over the top-k predictions for the specified values of k
Parameters
----------
output : pytorch tensor
the model output (predicted logits)
target : pytorch tensor
the ground-truth labels
topk : tuple
the values of k for which accuracy is computed, e.g., (1, 5)
Returns
-------
list
the top-k accuracies (in percent), one entry per value in topk
"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: the slice may be non-contiguous
res.append(correct_k.mul_(100.0 / batch_size))
return res