Unverified Commit 6f256c78 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Single Path One Shot (#1849)

parent 4f3ee9cb
data data
checkpoints checkpoints
runs runs
nni_auto_gen_search_space.json
# Single Path One-Shot Neural Architecture Search with Uniform Sampling
Single Path One-Shot by Megvii Research. [Paper link](https://arxiv.org/abs/1904.00420). [Official repo](https://github.com/megvii-model/SinglePathOneShot).
Block search only. Channel search is not supported yet.
Only GPU version is provided here.
## Preparation
### Requirements
* PyTorch >= 1.2
* NVIDIA DALI >= 0.16 as we use DALI to accelerate the data loading of ImageNet. [Installation guide](https://docs.nvidia.com/deeplearning/sdk/dali-developer-guide/docs/installation.html)
### Data
Need to download the flops lookup table from [here](https://1drv.ms/u/s!Am_mmG2-KsrnajesvSdfsq_cN48?e=aHVppN).
Put `op_flops_dict.pkl` and `checkpoint-150000.pth.tar` (if you don't want to retrain the supernet) under `data` directory.
Prepare ImageNet in the standard format (follow the script [here](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4)). Link it to `data/imagenet` will be more convenient.
After preparation, it's expected to have the following code structure:
```
spos
├── architecture_final.json
├── blocks.py
├── config_search.yml
├── data
│   ├── imagenet
│   │   ├── train
│   │   └── val
│   └── op_flops_dict.pkl
├── dataloader.py
├── network.py
├── readme.md
├── scratch.py
├── supernet.py
├── tester.py
├── tuner.py
└── utils.py
```
## Step 1. Train Supernet
```
python supernet.py
```
Will export the checkpoint to checkpoints directory, for the next step.
NOTE: The data loading used in the official repo is [slightly different from usual](https://github.com/megvii-model/SinglePathOneShot/issues/5), as they use BGR tensor and keep the values between 0 and 255 intentionally to align with their own DL framework. The option `--spos-preprocessing` will simulate the behavior used originally and enable you to use the checkpoints pretrained.
## Step 2. Evolution Search
Single Path One-Shot leverages evolution algorithm to search for the best architecture. The tester, which is responsible for testing the sampled architecture, recalculates all the batch norm for a subset of training images, and evaluates the architecture on the full validation set.
To have a search space ready for NNI framework, first run
```
nnictl ss_gen -t "python tester.py"
```
This will generate a file called `nni_auto_gen_search_space.json`, which is a serialized representation of your search space.
Then search with evolution tuner.
```
nnictl create --config config_search.yml
```
The final architecture exported from every epoch of evolution can be found in `checkpoints` under the working directory of your tuner, which, by default, is `$HOME/nni/experiments/your_experiment_id/log`.
## Step 3. Train from Scratch
```
python scratch.py
```
By default, it will use `architecture_final.json`. This architecture is provided by the official repo (converted into NNI format). You can use any architecture (e.g., the architecture found in step 2) with `--fixed-arc` option.
## Current Reproduction Results
Reproduction is still undergoing. Due to the gap between official release and original paper, we compare our current results with official repo (our run) and paper.
* Evolution phase is almost aligned with official repo. Our evolution algorithm shows a converging trend and reaches ~65% accuracy at the end of search. Nevertheless, this result is not on par with paper. For details, please refer to [this issue](https://github.com/megvii-model/SinglePathOneShot/issues/6).
* Retrain phase is not aligned. Our retraining code, which uses the architecture released by the authors, reaches 72.14% accuracy, still having a gap towards 73.61% by official release and 74.3% reported in original paper.
{
"LayerChoice1": [false, false, true, false],
"LayerChoice2": [false, true, false, false],
"LayerChoice3": [true, false, false, false],
"LayerChoice4": [false, true, false, false],
"LayerChoice5": [false, false, true, false],
"LayerChoice6": [true, false, false, false],
"LayerChoice7": [false, false, true, false],
"LayerChoice8": [true, false, false, false],
"LayerChoice9": [false, false, true, false],
"LayerChoice10": [true, false, false, false],
"LayerChoice11": [false, false, true, false],
"LayerChoice12": [false, false, false, true],
"LayerChoice13": [true, false, false, false],
"LayerChoice14": [true, false, false, false],
"LayerChoice15": [true, false, false, false],
"LayerChoice16": [true, false, false, false],
"LayerChoice17": [false, false, false, true],
"LayerChoice18": [false, false, true, false],
"LayerChoice19": [false, false, false, true],
"LayerChoice20": [false, false, false, true]
}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
class ShuffleNetBlock(nn.Module):
"""
When stride = 1, the block receives input with 2 * inp channels. Otherwise inp channels.
"""
def __init__(self, inp, oup, mid_channels, ksize, stride, sequence="pdp"):
super().__init__()
assert stride in [1, 2]
assert ksize in [3, 5, 7]
self.channels = inp // 2 if stride == 1 else inp
self.inp = inp
self.oup = oup
self.mid_channels = mid_channels
self.ksize = ksize
self.stride = stride
self.pad = ksize // 2
self.oup_main = oup - self.channels
assert self.oup_main > 0
self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))
if stride == 2:
self.branch_proj = nn.Sequential(
# dw
nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad,
groups=self.channels, bias=False),
nn.BatchNorm2d(self.channels, affine=False),
# pw-linear
nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(self.channels, affine=False),
nn.ReLU(inplace=True)
)
def forward(self, x):
if self.stride == 2:
x_proj, x = self.branch_proj(x), x
else:
x_proj, x = self._channel_shuffle(x)
return torch.cat((x_proj, self.branch_main(x)), 1)
def _decode_point_depth_conv(self, sequence):
result = []
first_depth = first_point = True
pc = c = self.channels
for i, token in enumerate(sequence):
# compute output channels of this conv
if i + 1 == len(sequence):
assert token == "p", "Last conv must be point-wise conv."
c = self.oup_main
elif token == "p" and first_point:
c = self.mid_channels
if token == "d":
# depth-wise conv
assert pc == c, "Depth-wise conv must not change channels."
result.append(nn.Conv2d(pc, c, self.ksize, self.stride if first_depth else 1, self.pad,
groups=c, bias=False))
result.append(nn.BatchNorm2d(c, affine=False))
first_depth = False
elif token == "p":
# point-wise conv
result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False))
result.append(nn.BatchNorm2d(c, affine=False))
result.append(nn.ReLU(inplace=True))
first_point = False
else:
raise ValueError("Conv sequence must be d and p.")
pc = c
return result
def _channel_shuffle(self, x):
bs, num_channels, height, width = x.data.size()
assert (num_channels % 4 == 0)
x = x.reshape(bs * num_channels // 2, 2, height * width)
x = x.permute(1, 0, 2)
x = x.reshape(2, -1, num_channels // 2, height, width)
return x[0], x[1]
class ShuffleXceptionBlock(ShuffleNetBlock):
def __init__(self, inp, oup, mid_channels, stride):
super().__init__(inp, oup, mid_channels, 3, stride, "dpdpdp")
authorName: unknown
experimentName: SPOS Search
trialConcurrency: 4
maxExecDuration: 7d
maxTrialNum: 99999
trainingServicePlatform: local
searchSpacePath: nni_auto_gen_search_space.json
useAnnotation: false
tuner:
codeDir: .
classFileName: tuner.py
className: EvolutionWithFlops
trial:
command: python tester.py --imagenet-dir /path/to/your/imagenet --spos-prep
codeDir: .
gpuNum: 1
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import torch.utils.data
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
class HybridTrainPipe(Pipeline):
def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed=12, local_rank=0, world_size=1,
spos_pre=False):
super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id)
color_space_type = types.BGR if spos_pre else types.RGB
self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size, random_shuffle=True)
self.decode = ops.ImageDecoder(device="mixed", output_type=color_space_type)
self.res = ops.RandomResizedCrop(device="gpu", size=crop,
interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR)
self.twist = ops.ColorTwist(device="gpu")
self.jitter_rng = ops.Uniform(range=[0.6, 1.4])
self.cmnp = ops.CropMirrorNormalize(device="gpu",
output_dtype=types.FLOAT,
output_layout=types.NCHW,
image_type=color_space_type,
mean=0. if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255],
std=1. if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255])
self.coin = ops.CoinFlip(probability=0.5)
def define_graph(self):
rng = self.coin()
self.jpegs, self.labels = self.input(name="Reader")
images = self.decode(self.jpegs)
images = self.res(images)
images = self.twist(images, saturation=self.jitter_rng(),
contrast=self.jitter_rng(), brightness=self.jitter_rng())
output = self.cmnp(images, mirror=rng)
return [output, self.labels]
class HybridValPipe(Pipeline):
def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, seed=12, local_rank=0, world_size=1,
spos_pre=False, shuffle=False):
super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id)
color_space_type = types.BGR if spos_pre else types.RGB
self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size,
random_shuffle=shuffle)
self.decode = ops.ImageDecoder(device="mixed", output_type=color_space_type)
self.res = ops.Resize(device="gpu", resize_shorter=size,
interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR)
self.cmnp = ops.CropMirrorNormalize(device="gpu",
output_dtype=types.FLOAT,
output_layout=types.NCHW,
crop=(crop, crop),
image_type=color_space_type,
mean=0. if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255],
std=1. if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255])
def define_graph(self):
self.jpegs, self.labels = self.input(name="Reader")
images = self.decode(self.jpegs)
images = self.res(images)
output = self.cmnp(images)
return [output, self.labels]
class ClassificationWrapper:
def __init__(self, loader, size):
self.loader = loader
self.size = size
def __iter__(self):
return self
def __next__(self):
data = next(self.loader)
return data[0]["data"], data[0]["label"].view(-1).long().cuda(non_blocking=True)
def __len__(self):
return self.size
def get_imagenet_iter_dali(split, image_dir, batch_size, num_threads, crop=224, val_size=256,
spos_preprocessing=False, seed=12, shuffle=False, device_id=None):
world_size, local_rank = 1, 0
if device_id is None:
device_id = torch.cuda.device_count() - 1 # use last gpu
if split == "train":
pipeline = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id,
data_dir=os.path.join(image_dir, "train"), seed=seed,
crop=crop, world_size=world_size, local_rank=local_rank,
spos_pre=spos_preprocessing)
elif split == "val":
pipeline = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id,
data_dir=os.path.join(image_dir, "val"), seed=seed,
crop=crop, size=val_size, world_size=world_size, local_rank=local_rank,
spos_pre=spos_preprocessing, shuffle=shuffle)
else:
raise AssertionError
pipeline.build()
num_samples = pipeline.epoch_size("Reader")
return ClassificationWrapper(
DALIClassificationIterator(pipeline, size=num_samples, fill_last_batch=split == "train",
auto_reset=True), (num_samples + batch_size - 1) // batch_size)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import pickle
import re
import torch
import torch.nn as nn
from nni.nas.pytorch import mutables
from blocks import ShuffleNetBlock, ShuffleXceptionBlock
class ShuffleNetV2OneShot(nn.Module):
block_keys = [
'shufflenet_3x3',
'shufflenet_5x5',
'shufflenet_7x7',
'xception_3x3',
]
def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=1024, n_classes=1000,
op_flops_path="./data/op_flops_dict.pkl"):
super().__init__()
assert input_size % 32 == 0
with open(os.path.join(os.path.dirname(__file__), op_flops_path), "rb") as fp:
self._op_flops_dict = pickle.load(fp)
self.stage_blocks = [4, 4, 8, 4]
self.stage_channels = [64, 160, 320, 640]
self._parsed_flops = dict()
self._input_size = input_size
self._feature_map_size = input_size
self._first_conv_channels = first_conv_channels
self._last_conv_channels = last_conv_channels
self._n_classes = n_classes
# building first layer
self.first_conv = nn.Sequential(
nn.Conv2d(3, first_conv_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(first_conv_channels, affine=False),
nn.ReLU(inplace=True),
)
self._feature_map_size //= 2
p_channels = first_conv_channels
features = []
for num_blocks, channels in zip(self.stage_blocks, self.stage_channels):
features.extend(self._make_blocks(num_blocks, p_channels, channels))
p_channels = channels
self.features = nn.Sequential(*features)
self.conv_last = nn.Sequential(
nn.Conv2d(p_channels, last_conv_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(last_conv_channels, affine=False),
nn.ReLU(inplace=True),
)
self.globalpool = nn.AvgPool2d(self._feature_map_size)
self.dropout = nn.Dropout(0.1)
self.classifier = nn.Sequential(
nn.Linear(last_conv_channels, n_classes, bias=False),
)
self._initialize_weights()
def _make_blocks(self, blocks, in_channels, channels):
result = []
for i in range(blocks):
stride = 2 if i == 0 else 1
inp = in_channels if i == 0 else channels
oup = channels
base_mid_channels = channels // 2
mid_channels = int(base_mid_channels) # prepare for scale
choice_block = mutables.LayerChoice([
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride),
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride),
ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride),
ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride)
])
result.append(choice_block)
# find the corresponding flops
flop_key = (inp, oup, mid_channels, self._feature_map_size, self._feature_map_size, stride)
self._parsed_flops[choice_block.key] = [
self._op_flops_dict["{}_stride_{}".format(k, stride)][flop_key] for k in self.block_keys
]
if stride == 2:
self._feature_map_size //= 2
return result
def forward(self, x):
bs = x.size(0)
x = self.first_conv(x)
x = self.features(x)
x = self.conv_last(x)
x = self.globalpool(x)
x = self.dropout(x)
x = x.contiguous().view(bs, -1)
x = self.classifier(x)
return x
def get_candidate_flops(self, candidate):
conv1_flops = self._op_flops_dict["conv1"][(3, self._first_conv_channels,
self._input_size, self._input_size, 2)]
# Should use `last_conv_channels` here, but megvii insists that it's `n_classes`. Keeping it.
# https://github.com/megvii-model/SinglePathOneShot/blob/36eed6cf083497ffa9cfe7b8da25bb0b6ba5a452/src/Supernet/flops.py#L313
rest_flops = self._op_flops_dict["rest_operation"][(self.stage_channels[-1], self._n_classes,
self._feature_map_size, self._feature_map_size, 1)]
total_flops = conv1_flops + rest_flops
for k, m in candidate.items():
parsed_flops_dict = self._parsed_flops[k]
if isinstance(m, dict): # to be compatible with classical nas format
total_flops += parsed_flops_dict[m["_idx"]]
else:
total_flops += parsed_flops_dict[torch.max(m, 0)[1]]
return total_flops
def _initialize_weights(self):
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'first' in name:
nn.init.normal_(m.weight, 0, 0.01)
else:
nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0.0001)
nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0.0001)
nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
checkpoint = torch.load(filepath, map_location=torch.device("cpu"))
result = dict()
for k, v in checkpoint["state_dict"].items():
if k.startswith("module."):
k = k[len("module."):]
k = re.sub(r"^(features.\d+).(\d+)", "\\1.choices.\\2", k)
result[k] = v
return result
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import logging
import random
import numpy as np
import torch
import torch.nn as nn
from dataloader import get_imagenet_iter_dali
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeterGroup
from torch.utils.tensorboard import SummaryWriter
from network import ShuffleNetV2OneShot
from utils import CrossEntropyLabelSmooth, accuracy
logger = logging.getLogger("nni.spos.scratch")
def train(epoch, model, criterion, optimizer, loader, writer, args):
model.train()
meters = AverageMeterGroup()
cur_lr = optimizer.param_groups[0]["lr"]
for step, (x, y) in enumerate(loader):
cur_step = len(loader) * epoch + step
optimizer.zero_grad()
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optimizer.step()
metrics = accuracy(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
writer.add_scalar("lr", cur_lr, global_step=cur_step)
writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
writer.add_scalar("acc1/train", metrics["acc1"], global_step=cur_step)
writer.add_scalar("acc5/train", metrics["acc5"], global_step=cur_step)
if step % args.log_frequency == 0 or step + 1 == len(loader):
logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
args.epochs, step + 1, len(loader), meters)
logger.info("Epoch %d training summary: %s", epoch + 1, meters)
def validate(epoch, model, criterion, loader, writer, args):
model.eval()
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(loader):
logits = model(x)
loss = criterion(logits, y)
metrics = accuracy(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if step % args.log_frequency == 0 or step + 1 == len(loader):
logger.info("Epoch [%d/%d] Validation Step [%d/%d] %s", epoch + 1,
args.epochs, step + 1, len(loader), meters)
writer.add_scalar("loss/test", meters.loss.avg, global_step=epoch)
writer.add_scalar("acc1/test", meters.acc1.avg, global_step=epoch)
writer.add_scalar("acc5/test", meters.acc5.avg, global_step=epoch)
logger.info("Epoch %d validation: top1 = %f, top5 = %f", epoch + 1, meters.acc1.avg, meters.acc5.avg)
if __name__ == "__main__":
parser = argparse.ArgumentParser("SPOS Training From Scratch")
parser.add_argument("--imagenet-dir", type=str, default="./data/imagenet")
parser.add_argument("--tb-dir", type=str, default="runs")
parser.add_argument("--architecture", type=str, default="architecture_final.json")
parser.add_argument("--workers", type=int, default=12)
parser.add_argument("--batch-size", type=int, default=1024)
parser.add_argument("--epochs", type=int, default=240)
parser.add_argument("--learning-rate", type=float, default=0.5)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight-decay", type=float, default=4E-5)
parser.add_argument("--label-smooth", type=float, default=0.1)
parser.add_argument("--log-frequency", type=int, default=10)
parser.add_argument("--lr-decay", type=str, default="linear")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--spos-preprocessing", default=False, action="store_true")
parser.add_argument("--label-smoothing", type=float, default=0.1)
args = parser.parse_args()
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
torch.backends.cudnn.deterministic = True
model = ShuffleNetV2OneShot()
model.cuda()
apply_fixed_architecture(model, args.architecture)
if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu
model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1)))
criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
momentum=args.momentum, weight_decay=args.weight_decay)
if args.lr_decay == "linear":
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
lambda step: (1.0 - step / args.epochs)
if step <= args.epochs else 0,
last_epoch=-1)
elif args.lr_decay == "cosine":
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, 1E-3)
else:
raise ValueError("'%s' not supported." % args.lr_decay)
writer = SummaryWriter(log_dir=args.tb_dir)
train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.batch_size, args.workers,
spos_preprocessing=args.spos_preprocessing)
val_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.batch_size, args.workers,
spos_preprocessing=args.spos_preprocessing)
for epoch in range(args.epochs):
train(epoch, model, criterion, optimizer, train_loader, writer, args)
validate(epoch, model, criterion, val_loader, writer, args)
scheduler.step()
writer.close()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import logging
import random
import numpy as np
import torch
import torch.nn as nn
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.nas.pytorch.callbacks import ModelCheckpoint
from nni.nas.pytorch.spos import SPOSSupernetTrainingMutator, SPOSSupernetTrainer
from dataloader import get_imagenet_iter_dali
from network import ShuffleNetV2OneShot, load_and_parse_state_dict
from utils import CrossEntropyLabelSmooth, accuracy
logger = logging.getLogger("nni.spos.supernet")
if __name__ == "__main__":
parser = argparse.ArgumentParser("SPOS Supernet Training")
parser.add_argument("--imagenet-dir", type=str, default="./data/imagenet")
parser.add_argument("--load-checkpoint", action="store_true", default=False)
parser.add_argument("--spos-preprocessing", action="store_true", default=False,
help="When true, image values will range from 0 to 255 and use BGR "
"(as in original repo).")
parser.add_argument("--workers", type=int, default=4)
parser.add_argument("--batch-size", type=int, default=768)
parser.add_argument("--epochs", type=int, default=120)
parser.add_argument("--learning-rate", type=float, default=0.5)
parser.add_argument("--momentum", type=float, default=0.9)
parser.add_argument("--weight-decay", type=float, default=4E-5)
parser.add_argument("--label-smooth", type=float, default=0.1)
parser.add_argument("--log-frequency", type=int, default=10)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--label-smoothing", type=float, default=0.1)
args = parser.parse_args()
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
torch.backends.cudnn.deterministic = True
model = ShuffleNetV2OneShot()
if args.load_checkpoint:
if not args.spos_preprocessing:
logger.warning("You might want to use SPOS preprocessing if you are loading their checkpoints.")
model.load_state_dict(load_and_parse_state_dict())
model.cuda()
if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu
model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1)))
mutator = SPOSSupernetTrainingMutator(model, flops_func=model.module.get_candidate_flops,
flops_lb=290E6, flops_ub=360E6)
criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
momentum=args.momentum, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
lambda step: (1.0 - step / args.epochs)
if step <= args.epochs else 0,
last_epoch=-1)
train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.batch_size, args.workers,
spos_preprocessing=args.spos_preprocessing)
valid_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.batch_size, args.workers,
spos_preprocessing=args.spos_preprocessing)
trainer = SPOSSupernetTrainer(model, criterion, accuracy, optimizer,
args.epochs, train_loader, valid_loader,
mutator=mutator, batch_size=args.batch_size,
log_frequency=args.log_frequency, workers=args.workers,
callbacks=[LRSchedulerCallback(scheduler),
ModelCheckpoint("./checkpoints")])
trainer.train()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
import logging
import random
import time
from itertools import cycle
import nni
import numpy as np
import torch
import torch.nn as nn
from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture
from nni.nas.pytorch.utils import AverageMeterGroup
from dataloader import get_imagenet_iter_dali
from network import ShuffleNetV2OneShot, load_and_parse_state_dict
from utils import CrossEntropyLabelSmooth, accuracy
logger = logging.getLogger("nni.spos.tester")
def retrain_bn(model, criterion, max_iters, log_freq, loader):
with torch.no_grad():
logger.info("Clear BN statistics...")
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
m.running_mean = torch.zeros_like(m.running_mean)
m.running_var = torch.ones_like(m.running_var)
logger.info("Train BN with training set (BN sanitize)...")
model.train()
meters = AverageMeterGroup()
for step in range(max_iters):
inputs, targets = next(loader)
logits = model(inputs)
loss = criterion(logits, targets)
metrics = accuracy(logits, targets)
metrics["loss"] = loss.item()
meters.update(metrics)
if step % log_freq == 0 or step + 1 == max_iters:
logger.info("Train Step [%d/%d] %s", step + 1, max_iters, meters)
def test_acc(model, criterion, log_freq, loader):
logger.info("Start testing...")
model.eval()
meters = AverageMeterGroup()
start_time = time.time()
with torch.no_grad():
for step, (inputs, targets) in enumerate(loader):
logits = model(inputs)
loss = criterion(logits, targets)
metrics = accuracy(logits, targets)
metrics["loss"] = loss.item()
meters.update(metrics)
if step % log_freq == 0 or step + 1 == len(loader):
logger.info("Valid Step [%d/%d] time %.3fs acc1 %.4f acc5 %.4f loss %.4f",
step + 1, len(loader), time.time() - start_time,
meters.acc1.avg, meters.acc5.avg, meters.loss.avg)
return meters.acc1.avg
def evaluate_acc(model, criterion, args, loader_train, loader_test):
acc_before = test_acc(model, criterion, args.log_frequency, loader_test)
nni.report_intermediate_result(acc_before)
retrain_bn(model, criterion, args.train_iters, args.log_frequency, loader_train)
acc = test_acc(model, criterion, args.log_frequency, loader_test)
assert isinstance(acc, float)
nni.report_intermediate_result(acc)
nni.report_final_result(acc)
if __name__ == "__main__":
parser = argparse.ArgumentParser("SPOS Candidate Tester")
parser.add_argument("--imagenet-dir", type=str, default="./data/imagenet")
parser.add_argument("--checkpoint", type=str, default="./data/checkpoint-150000.pth.tar")
parser.add_argument("--spos-preprocessing", action="store_true", default=False,
help="When true, image values will range from 0 to 255 and use BGR "
"(as in original repo).")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--workers", type=int, default=6)
parser.add_argument("--train-batch-size", type=int, default=128)
parser.add_argument("--train-iters", type=int, default=200)
parser.add_argument("--test-batch-size", type=int, default=512)
parser.add_argument("--log-frequency", type=int, default=10)
args = parser.parse_args()
# use a fixed set of image will improve the performance
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
torch.backends.cudnn.deterministic = True
assert torch.cuda.is_available()
model = ShuffleNetV2OneShot()
criterion = CrossEntropyLabelSmooth(1000, 0.1)
get_and_apply_next_architecture(model)
model.load_state_dict(load_and_parse_state_dict(filepath=args.checkpoint))
model.cuda()
train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.train_batch_size, args.workers,
spos_preprocessing=args.spos_preprocessing,
seed=args.seed, device_id=0)
val_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.test_batch_size, args.workers,
spos_preprocessing=args.spos_preprocessing, shuffle=True,
seed=args.seed, device_id=0)
train_loader = cycle(train_loader)
evaluate_acc(model, criterion, args, train_loader, val_loader)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.nas.pytorch.spos import SPOSEvolution
from network import ShuffleNetV2OneShot
class EvolutionWithFlops(SPOSEvolution):
"""
This tuner extends the function of evolution tuner, by limiting the flops generated by tuner.
Needs a function to examine the flops.
"""
def __init__(self, flops_limit=330E6, **kwargs):
super().__init__(**kwargs)
self.model = ShuffleNetV2OneShot()
self.flops_limit = flops_limit
def _is_legal(self, cand):
if not super()._is_legal(cand):
return False
if self.model.get_candidate_flops(cand) > self.flops_limit:
return False
return True
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
def accuracy(output, target, topk=(1, 5)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .evolution import SPOSEvolution
from .mutator import SPOSSupernetTrainingMutator
from .trainer import SPOSSupernetTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import re
from collections import deque
import numpy as np
from nni.tuner import Tuner
from nni.nas.pytorch.classic_nas.mutator import LAYER_CHOICE, INPUT_CHOICE
_logger = logging.getLogger(__name__)
class SPOSEvolution(Tuner):
def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
num_crossover=25, num_mutation=25):
"""
Initialize SPOS Evolution Tuner.
Parameters
----------
max_epochs : int
Maximum number of epochs to run.
num_select : int
Number of survival candidates of each epoch.
num_population : int
Number of candidates at the start of each epoch. If candidates generated by
crossover and mutation are not enough, the rest will be filled with random
candidates.
m_prob : float
The probability of mutation.
num_crossover : int
Number of candidates generated by crossover in each epoch.
num_mutation : int
Number of candidates generated by mutation in each epoch.
"""
assert num_population >= num_select
self.max_epochs = max_epochs
self.num_select = num_select
self.num_population = num_population
self.m_prob = m_prob
self.num_crossover = num_crossover
self.num_mutation = num_mutation
self.epoch = 0
self.candidates = []
self.search_space = None
self.random_state = np.random.RandomState(0)
# async status
self._to_evaluate_queue = deque()
self._sending_parameter_queue = deque()
self._pending_result_ids = set()
self._reward_dict = dict()
self._id2candidate = dict()
self._st_callback = None
def update_search_space(self, search_space):
"""
Handle the initialization/update event of search space.
"""
self._search_space = search_space
self._next_round()
def _next_round(self):
_logger.info("Epoch %d, generating...", self.epoch)
if self.epoch == 0:
self._get_random_population()
self.export_results(self.candidates)
else:
best_candidates = self._select_top_candidates()
self.export_results(best_candidates)
if self.epoch >= self.max_epochs:
return
self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates)
self._get_random_population()
self.epoch += 1
def _random_candidate(self):
chosen_arch = dict()
for key, val in self._search_space.items():
if val["_type"] == LAYER_CHOICE:
choices = val["_value"]
index = self.random_state.randint(len(choices))
chosen_arch[key] = {"_value": choices[index], "_idx": index}
elif val["_type"] == INPUT_CHOICE:
raise NotImplementedError("Input choice is not implemented yet.")
return chosen_arch
def _add_to_evaluate_queue(self, cand):
_logger.info("Generate candidate %s, adding to eval queue.", self._get_architecture_repr(cand))
self._reward_dict[self._hashcode(cand)] = 0.
self._to_evaluate_queue.append(cand)
def _get_random_population(self):
while len(self.candidates) < self.num_population:
cand = self._random_candidate()
if self._is_legal(cand):
_logger.info("Random candidate generated.")
self._add_to_evaluate_queue(cand)
self.candidates.append(cand)
def _get_crossover(self, best):
result = []
for _ in range(10 * self.num_crossover):
cand_p1 = best[self.random_state.randint(len(best))]
cand_p2 = best[self.random_state.randint(len(best))]
assert cand_p1.keys() == cand_p2.keys()
cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k]
for k in cand_p1.keys()}
if self._is_legal(cand):
result.append(cand)
self._add_to_evaluate_queue(cand)
if len(result) >= self.num_crossover:
break
_logger.info("Found %d architectures with crossover.", len(result))
return result
def _get_mutation(self, best):
result = []
for _ in range(10 * self.num_mutation):
cand = best[self.random_state.randint(len(best))].copy()
mutation_sample = np.random.random_sample(len(cand))
for s, k in zip(mutation_sample, cand):
if s < self.m_prob:
choices = self._search_space[k]["_value"]
index = self.random_state.randint(len(choices))
cand[k] = {"_value": choices[index], "_idx": index}
if self._is_legal(cand):
result.append(cand)
self._add_to_evaluate_queue(cand)
if len(result) >= self.num_mutation:
break
_logger.info("Found %d architectures with mutation.", len(result))
return result
def _get_architecture_repr(self, cand):
return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1",
self._hashcode(cand))
def _is_legal(self, cand):
if self._hashcode(cand) in self._reward_dict:
return False
return True
def _select_top_candidates(self):
reward_query = lambda cand: self._reward_dict[self._hashcode(cand)]
_logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates)))
result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select]
_logger.info("Best candidate rewards: %s", list(map(reward_query, result)))
return result
@staticmethod
def _hashcode(d):
return json.dumps(d, sort_keys=True)
def _bind_and_send_parameters(self):
"""
There are two types of resources: parameter ids and candidates. This function is called at
necessary times to bind these resources to send new trials with st_callback.
"""
result = []
while self._sending_parameter_queue and self._to_evaluate_queue:
parameter_id = self._sending_parameter_queue.popleft()
parameters = self._to_evaluate_queue.popleft()
self._id2candidate[parameter_id] = parameters
result.append(parameters)
self._pending_result_ids.add(parameter_id)
self._st_callback(parameter_id, parameters)
_logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters))
return result
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
"""
Callback function necessary to implement a tuner. This will put more parameter ids into the
parameter id queue.
"""
if "st_callback" in kwargs and self._st_callback is None:
self._st_callback = kwargs["st_callback"]
for parameter_id in parameter_id_list:
self._sending_parameter_queue.append(parameter_id)
self._bind_and_send_parameters()
return [] # always not use this. might induce problem of over-sending
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
Callback function. Receive a trial result.
"""
_logger.info("Candidate %d, reported reward %f", parameter_id, value)
self._reward_dict[self._hashcode(self._id2candidate[parameter_id])] = value
def trial_end(self, parameter_id, success, **kwargs):
"""
Callback function when a trial is ended and resource is released.
"""
self._pending_result_ids.remove(parameter_id)
if not self._pending_result_ids and not self._to_evaluate_queue:
# a new epoch now
self._next_round()
assert self._st_callback is not None
self._bind_and_send_parameters()
def export_results(self, result):
"""
Export a number of candidates to `checkpoints` dir.
Parameters
----------
result : dict
"""
os.makedirs("checkpoints", exist_ok=True)
for i, cand in enumerate(result):
converted = dict()
for cand_key, cand_val in cand.items():
onehot = [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))]
converted[cand_key] = onehot
with open(os.path.join("checkpoints", "%03d_%03d.json" % (self.epoch, i)), "w") as fp:
json.dump(converted, fp)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import numpy as np
from nni.nas.pytorch.random import RandomMutator
_logger = logging.getLogger(__name__)
class SPOSSupernetTrainingMutator(RandomMutator):
def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None,
flops_bin_num=7, flops_sample_timeout=500):
"""
Parameters
----------
model : nn.Module
flops_func : callable
Callable that takes a candidate from `sample_search` and returns its candidate. When `flops_func`
is None, functions related to flops will be deactivated.
flops_lb : number
Lower bound of flops.
flops_ub : number
Upper bound of flops.
flops_bin_num : number
Number of bins divided for the interval of flops to ensure the uniformity. Bigger number will be more
uniform, but the sampling will be slower.
flops_sample_timeout : int
Maximum number of attempts to sample before giving up and use a random candidate.
"""
super().__init__(model)
self._flops_func = flops_func
if self._flops_func is not None:
self._flops_bin_num = flops_bin_num
self._flops_bins = [flops_lb + (flops_ub - flops_lb) / flops_bin_num * i for i in range(flops_bin_num + 1)]
self._flops_sample_timeout = flops_sample_timeout
def sample_search(self):
"""
Sample a candidate for training. When `flops_func` is not None, candidates will be sampled uniformly
relative to flops.
Returns
-------
dict
"""
if self._flops_func is not None:
for times in range(self._flops_sample_timeout):
idx = np.random.randint(self._flops_bin_num)
cand = super().sample_search()
if self._flops_bins[idx] <= self._flops_func(cand) <= self._flops_bins[idx + 1]:
_logger.debug("Sampled candidate flops %f in %d times.", cand, times)
return cand
_logger.warning("Failed to sample a flops-valid candidate within %d tries.", self._flops_sample_timeout)
return super().sample_search()
def sample_final(self):
"""
Implement only to suffice the interface of Mutator.
"""
return self.sample_search()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import SPOSSupernetTrainingMutator
logger = logging.getLogger(__name__)
class SPOSSupernetTrainer(Trainer):
"""
This trainer trains a supernet that can be used for evolution search.
"""
def __init__(self, model, loss, metrics,
optimizer, num_epochs, train_loader, valid_loader,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None):
assert torch.cuda.is_available()
super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
loss, metrics, optimizer, num_epochs, None, None,
batch_size, workers, device, log_frequency, callbacks)
self.train_loader = train_loader
self.valid_loader = valid_loader
def train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
self.optimizer.zero_grad()
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
loss.backward()
self.optimizer.step()
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def validate_one_epoch(self, epoch):
self.model.eval()
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.valid_loader), meters)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment