Commit 13d03757 authored by Houwen Peng's avatar Houwen Peng Committed by Yuge Zhang
Browse files

integrate c-darts nas algorithm (#1955)

parent a9711e24
......@@ -126,6 +126,7 @@ Within the following table, we summarized the current NNI capabilities, we are g
<li><a href="docs/en_US/NAS/Overview.md#enas">ENAS</a></li>
<li><a href="docs/en_US/NAS/Overview.md#darts">DARTS</a></li>
<li><a href="docs/en_US/NAS/Overview.md#p-darts">P-DARTS</a></li>
<li><a href="docs/en_US/NAS/Overview.md#cdarts">CDARTS</a></li>
<li><a href="docs/en_US/Tuner/BuiltinTuner.md#NetworkMorphism">Network Morphism</a> </li>
</ul>
</ul>
......
# CDARTS
## Introduction
CDARTS builds a cyclic feedback mechanism between the search and evaluation networks. First, the search network generates an initial topology for evaluation, so that the weights of the evaluation network can be optimized. Second, the architecture topology in the search network is further optimized by the label supervision in classification, as well as the regularization from the evaluation network through feature distillation. Repeating the above cycle results in a joint optimization of the search and evaluation networks, and thus enables the evolution of the topology to fit the final evaluation network.
In implementation of `CdartsTrainer`, it first instantiates two models and two mutators (one for each). The first model is the so-called "search network", which is mutated with a `RegularizedDartsMutator` -- a mutator with subtle differences with `DartsMutator`. The second model is the "evaluation network", which is mutated with a discrete mutator that leverages the previous search network mutator, to sample a single path each time. Trainers train models and mutators alternatively. Users can refer to [references](#reference) if they are interested in more details on these trainers and mutators.
## Reproduction Results
This is CDARTS based on the NNI platform, which currently supports CIFAR10 search and retrain. ImageNet search and retrain should also be supported, and we provide corresponding interfaces. Our reproduced results on NNI are slightly lower than the paper, but much higher than the original DARTS. Here we show the results of three independent experiments on CIFAR10.
| Runs | Paper | NNI |
| ---- |:-------------:| :-----:|
| 1 | 97.52 | 97.44 |
| 2 | 97.53 | 97.48 |
| 3 | 97.58 | 97.56 |
## Examples
[Example code](https://github.com/microsoft/nni/tree/master/examples/nas/cdarts)
```bash
# In case NNI code is not cloned. If the code is cloned already, ignore this line and enter code folder.
git clone https://github.com/Microsoft/nni.git
# install apex for distributed training.
git clone https://github.com/NVIDIA/apex
cd apex
python setup.py install --cpp_ext --cuda_ext
# search the best architecture
cd examples/nas/cdarts
bash run_search_cifar.sh
# train the best architecture.
bash run_retrain_cifar.sh
```
## Reference
### PyTorch
```eval_rst
.. autoclass:: nni.nas.pytorch.cdarts.CdartsTrainer
:members:
.. automethod:: __init__
.. autoclass:: nni.nas.pytorch.cdarts.RegularizedDartsMutator
:members:
.. autoclass:: nni.nas.pytorch.cdarts.DartsDiscreteMutator
:members:
.. automethod:: __init__
.. autoclass:: nni.nas.pytorch.cdarts.RegularizedMutatorParallel
:members:
```
......@@ -22,6 +22,7 @@ NNI supports below NAS algorithms now and is adding more. User can reproduce an
| [DARTS](DARTS.md) | [DARTS: Differentiable Architecture Search](https://arxiv.org/abs/1806.09055) introduces a novel algorithm for differentiable network architecture search on bilevel optimization. |
| [P-DARTS](PDARTS.md) | [Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation](https://arxiv.org/abs/1904.12760) is based on DARTS. It introduces an efficient algorithm which allows the depth of searched architectures to grow gradually during the training procedure. |
| [SPOS](SPOS.md) | [Single Path One-Shot Neural Architecture Search with Uniform Sampling](https://arxiv.org/abs/1904.00420) constructs a simplified supernet trained with an uniform path sampling method, and applies an evolutionary algorithm to efficiently search for the best-performing architectures. |
| [CDARTS](CDARTS.md) | [Cyclic Differentiable Architecture Search](https://arxiv.org/abs/****) builds a cyclic feedback mechanism between the search and evaluation networks. It introduces a cyclic differentiable architecture search framework which integrates the two networks into a unified architecture.|
One-shot algorithms run **standalone without nnictl**. Only PyTorch version has been implemented. Tensorflow 2.x will be supported in future release.
......
......@@ -47,6 +47,9 @@ extensions = [
'sphinx.ext.napoleon',
]
# Add mock modules
autodoc_mock_imports = ['apex']
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
......
......@@ -24,3 +24,4 @@ For details, please refer to the following tutorials:
DARTS <NAS/DARTS>
P-DARTS <NAS/PDARTS>
SPOS <NAS/SPOS>
CDARTS <NAS/CDARTS>
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch.nn as nn
class DistillHeadCIFAR(nn.Module):
def __init__(self, C, size, num_classes, bn_affine=False):
"""assuming input size 8x8 or 16x16"""
super(DistillHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(),
nn.AvgPool2d(size, stride=2, padding=0, count_include_pad=False), # image size = 2 x 2 / 6 x 6
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128, affine=bn_affine),
nn.ReLU(),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768, affine=bn_affine),
nn.ReLU()
)
self.classifier = nn.Linear(768, num_classes)
self.gap = nn.AdaptiveAvgPool2d(1)
def forward(self, x):
x = self.features(x)
x = self.gap(x)
x = self.classifier(x.view(x.size(0), -1))
return x
class DistillHeadImagenet(nn.Module):
def __init__(self, C, size, num_classes, bn_affine=False):
"""assuming input size 7x7 or 14x14"""
super(DistillHeadImagenet, self).__init__()
self.features = nn.Sequential(
nn.ReLU(),
nn.AvgPool2d(size, stride=2, padding=0, count_include_pad=False), # image size = 2 x 2 / 6 x 6
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128, affine=bn_affine),
nn.ReLU(),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768, affine=bn_affine),
nn.ReLU()
)
self.classifier = nn.Linear(768, num_classes)
self.gap = nn.AdaptiveAvgPool2d(1)
def forward(self, x):
x = self.features(x)
x = self.gap(x)
x = self.classifier(x.view(x.size(0), -1))
return x
class AuxiliaryHeadCIFAR(nn.Module):
def __init__(self, C, size=5, num_classes=10):
"""assuming input size 8x8"""
super(AuxiliaryHeadCIFAR, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0), -1))
return x
class AuxiliaryHeadImageNet(nn.Module):
def __init__(self, C, size=5, num_classes=1000):
"""assuming input size 7x7"""
super(AuxiliaryHeadImageNet, self).__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(size, stride=2, padding=0, count_include_pad=False),
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
# NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
# Commenting it out for consistency with the experiments in the paper.
# nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_classes)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0), -1))
return x
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import argparse
from functools import partial
def get_parser(name):
""" make default formatted parser """
parser = argparse.ArgumentParser(name, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# print default value always
parser.add_argument = partial(parser.add_argument, help=' ')
return parser
class BaseConfig(argparse.Namespace):
def print_params(self, prtf=print):
prtf("")
prtf("Parameters:")
for attr, value in sorted(vars(self).items()):
prtf("{}={}".format(attr.upper(), value))
prtf("")
def as_markdown(self):
""" Return configs as markdown format """
text = "|name|value| \n|-|-| \n"
for attr, value in sorted(vars(self).items()):
text += "|{}|{}| \n".format(attr, value)
return text
class SearchConfig(BaseConfig):
def build_parser(self):
parser = get_parser("Search config")
########### basic settings ############
parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'imagenet'])
parser.add_argument('--n_classes', type=int, default=10)
parser.add_argument('--stem_multiplier', type=int, default=3)
parser.add_argument('--init_channels', type=int, default=16)
parser.add_argument('--data_dir', type=str, default='data/cifar', help='cifar dataset')
parser.add_argument('--output_path', type=str, default='./outputs', help='')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--log_frequency', type=int, default=10, help='print frequency')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--workers', type=int, default=4, help='# of workers')
parser.add_argument('--steps_per_epoch', type=int, default=None, help='how many steps per epoch, use None for one pass of dataset')
########### learning rate ############
parser.add_argument('--w_lr', type=float, default=0.05, help='lr for weights')
parser.add_argument('--w_momentum', type=float, default=0.9, help='momentum for weights')
parser.add_argument('--w_weight_decay', type=float, default=3e-4, help='weight decay for weights')
parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping for weights')
parser.add_argument('--alpha_lr', type=float, default=6e-4, help='lr for alpha')
parser.add_argument('--alpha_weight_decay', type=float, default=1e-3, help='weight decay for alpha')
parser.add_argument('--nasnet_lr', type=float, default=0.1, help='lr of nasnet')
########### alternate training ############
parser.add_argument('--epochs', type=int, default=32, help='# of search epochs')
parser.add_argument('--warmup_epochs', type=int, default=2, help='# warmup epochs of super model')
parser.add_argument('--loss_alpha', type=float, default=1, help='loss alpha')
parser.add_argument('--loss_T', type=float, default=2, help='loss temperature')
parser.add_argument('--interactive_type', type=str, default='kl', choices=['kl', 'smoothl1'])
parser.add_argument('--sync_bn', action='store_true', default=False, help='whether to sync bn')
parser.add_argument('--use_apex', action='store_true', default=False, help='whether to use apex')
parser.add_argument('--regular_ratio', type=float, default=0.5, help='regular ratio')
parser.add_argument('--regular_coeff', type=float, default=5, help='regular coefficient')
parser.add_argument('--fix_head', action='store_true', default=False, help='whether to fix head')
parser.add_argument('--share_module', action='store_true', default=False, help='whether to share stem and aux head')
########### data augument ############
parser.add_argument('--aux_weight', type=float, default=0.4, help='auxiliary loss weight')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path prob')
parser.add_argument('--use_aa', action='store_true', default=False, help='whether to use aa')
parser.add_argument('--mixup_alpha', default=1., type=float, help='mixup interpolation coefficient (default: 1)')
########### distributed ############
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--world_size", default=1, type=int)
parser.add_argument('--dist_url', default='tcp://127.0.0.1:23456', type=str, help='url used to set up distributed training')
parser.add_argument('--distributed', action='store_true', help='run model distributed mode')
return parser
def __init__(self):
parser = self.build_parser()
args = parser.parse_args()
super().__init__(**vars(args))
class RetrainConfig(BaseConfig):
def build_parser(self):
parser = get_parser("Retrain config")
parser.add_argument('--dataset', default="cifar10", choices=['cifar10', 'cifar100', 'imagenet'])
parser.add_argument('--data_dir', type=str, default='data/cifar', help='cifar dataset')
parser.add_argument('--output_path', type=str, default='./outputs', help='')
parser.add_argument("--arc_checkpoint", default="epoch_02.json")
parser.add_argument('--log_frequency', type=int, default=10, help='print frequency')
########### model settings ############
parser.add_argument('--n_classes', type=int, default=10)
parser.add_argument('--input_channels', type=int, default=3)
parser.add_argument('--stem_multiplier', type=int, default=3)
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--eval_batch_size', type=int, default=500, help='batch size for validation')
parser.add_argument('--lr', type=float, default=0.025, help='lr for weights')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping for weights')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
parser.add_argument('--epochs', type=int, default=600, help='# of training epochs')
parser.add_argument('--warmup_epochs', type=int, default=5, help='# warmup')
parser.add_argument('--init_channels', type=int, default=36)
parser.add_argument('--layers', type=int, default=20, help='# of layers')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--workers', type=int, default=4, help='# of workers')
parser.add_argument('--aux_weight', type=float, default=0.4, help='auxiliary loss weight')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path prob')
########### data augmentation ############
parser.add_argument('--use_aa', action='store_true', default=False, help='whether to use aa')
parser.add_argument('--mixup_alpha', default=1., type=float, help='mixup interpolation coefficient')
########### distributed ############
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--world_size", default=1, type=int)
parser.add_argument('--dist_url', default='tcp://127.0.0.1:23456', type=str, help='url used to set up distributed training')
parser.add_argument('--distributed', action='store_true', help='run model distributed mode')
return parser
def __init__(self):
parser = self.build_parser()
args = parser.parse_args()
super().__init__(**vars(args))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from datasets.data_utils import CIFAR10Policy, Cutout
from datasets.data_utils import SubsetDistributedSampler
def data_transforms_cifar(config, cutout=False):
CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
if config.use_aa:
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4, fill=128),
transforms.RandomHorizontalFlip(), CIFAR10Policy(),
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
else:
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
if cutout:
train_transform.transforms.append(Cutout(config.cutout_length))
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
return train_transform, valid_transform
def get_search_datasets(config):
dataset = config.dataset.lower()
if dataset == 'cifar10':
dset_cls = dset.CIFAR10
n_classes = 10
elif dataset == 'cifar100':
dset_cls = dset.CIFAR100
n_classes = 100
else:
raise Exception("Not support dataset!")
train_transform, valid_transform = data_transforms_cifar(config, cutout=False)
train_data = dset_cls(root=config.data_dir, train=True, download=True, transform=train_transform)
test_data = dset_cls(root=config.data_dir, train=False, download=True, transform=valid_transform)
num_train = len(train_data)
indices = list(range(num_train))
split_mid = int(np.floor(0.5 * num_train))
if config.distributed:
train_sampler = SubsetDistributedSampler(train_data, indices[:split_mid])
valid_sampler = SubsetDistributedSampler(train_data, indices[split_mid:num_train])
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split_mid])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split_mid:num_train])
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=config.batch_size,
sampler=train_sampler,
pin_memory=False, num_workers=config.workers)
valid_loader = torch.utils.data.DataLoader(
train_data, batch_size=config.batch_size,
sampler=valid_sampler,
pin_memory=False, num_workers=config.workers)
return [train_loader, valid_loader], [train_sampler, valid_sampler]
def get_augment_datasets(config):
dataset = config.dataset.lower()
if dataset == 'cifar10':
dset_cls = dset.CIFAR10
elif dataset == 'cifar100':
dset_cls = dset.CIFAR100
else:
raise Exception("Not support dataset!")
train_transform, valid_transform = data_transforms_cifar(config, cutout=True)
train_data = dset_cls(root=config.data_dir, train=True, download=True, transform=train_transform)
test_data = dset_cls(root=config.data_dir, train=False, download=True, transform=valid_transform)
if config.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
test_sampler = torch.utils.data.distributed.DistributedSampler(test_data)
else:
train_sampler = None
test_sampler = None
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=config.batch_size,
sampler=train_sampler,
pin_memory=True, num_workers=config.workers)
test_loader = torch.utils.data.DataLoader(
test_data, batch_size=config.eval_batch_size,
sampler=test_sampler,
pin_memory=True, num_workers=config.workers)
return [train_loader, test_loader], [train_sampler, test_sampler]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import random
import numpy as np
import torch
import torch.distributed as dist
from PIL import Image, ImageEnhance, ImageOps
from torch.utils.data import Sampler
class SubsetDistributedSampler(Sampler):
"""
Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.
Dataset is assumed to be of constant size.
"""
def __init__(self, dataset, indices, num_replicas=None, rank=None, shuffle=True):
"""
Initialization.
Parameters
----------
dataset : torch.utils.data.Dataset
Dataset used for sampling.
num_replicas : int
Number of processes participating in distributed training. Default: World size.
rank : int
Rank of the current process within num_replicas. Default: Current rank.
shuffle : bool
If true (default), sampler will shuffle the indices.
"""
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.indices = indices
self.num_samples = int(math.ceil(len(self.indices) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
# indices = torch.randperm(len(self.dataset), generator=g).tolist()
indices = list(self.indices[i] for i in torch.randperm(len(self.indices)))
else:
# indices = list(range(len(self.dataset)))
indices = self.indices
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
class data_prefetcher():
def __init__(self, loader):
self.loader = iter(loader)
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1, 3, 1, 1)
self.preload()
def preload(self):
try:
self.next_input, self.next_target = next(self.loader)
except StopIteration:
self.next_input = None
self.next_target = None
return
with torch.cuda.stream(self.stream):
self.next_input = self.next_input.cuda(non_blocking=True)
self.next_target = self.next_target.cuda(non_blocking=True)
self.next_input = self.next_input.float()
self.next_input = self.next_input.sub_(self.mean).div_(self.std)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
input = self.next_input
target = self.next_target
self.preload()
return input, target
class Cutout(object):
def __init__(self, length):
self.length = length
def __call__(self, img):
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
class ImageNetPolicy(object):
""" Randomly choose one of the best 24 Sub-policies on ImageNet.
Example:
>>> policy = ImageNetPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment ImageNet Policy"
class CIFAR10Policy(object):
""" Randomly choose one of the best 25 Sub-policies on CIFAR10.
Example:
>>> policy = CIFAR10Policy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> CIFAR10Policy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment CIFAR10 Policy"
class SVHNPolicy(object):
""" Randomly choose one of the best 25 Sub-policies on SVHN.
Example:
>>> policy = SVHNPolicy()
>>> transformed = policy(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> SVHNPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
self.policies = [
SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
]
def __call__(self, img):
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
return "AutoAugment SVHN Policy"
class SubPolicy(object):
def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
ranges = {
"shearX": np.linspace(0, 0.3, 10),
"shearY": np.linspace(0, 0.3, 10),
"translateX": np.linspace(0, 150 / 331, 10),
"translateY": np.linspace(0, 150 / 331, 10),
"rotate": np.linspace(0, 30, 10),
"color": np.linspace(0.0, 0.9, 10),
"posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int),
"solarize": np.linspace(256, 0, 10),
"contrast": np.linspace(0.0, 0.9, 10),
"sharpness": np.linspace(0.0, 0.9, 10),
"brightness": np.linspace(0.0, 0.9, 10),
"autocontrast": [0] * 10,
"equalize": [0] * 10,
"invert": [0] * 10
}
# from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)
func = {
"shearX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC, fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
"posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
}
self.p1 = p1
self.operation1 = func[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
self.p2 = p2
self.operation2 = func[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
if random.random() < self.p1:
img = self.operation1(img, self.magnitude1)
if random.random() < self.p2:
img = self.operation2(img, self.magnitude2)
return img
def fast_collate(batch):
imgs = [img[0] for img in batch]
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
w = imgs[0].size[0]
h = imgs[0].size[1]
tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8)
for i, img in enumerate(imgs):
nump_array = np.asarray(img, dtype=np.uint8)
if (nump_array.ndim < 3):
nump_array = np.expand_dims(nump_array, axis=-1)
nump_array = np.rollaxis(nump_array, 2)
tensor[i] += torch.from_numpy(nump_array)
return tensor, targets
def mixup_data(x, y, alpha=1.0, use_cuda=True):
'''Returns mixed inputs, pairs of targets, and lambda'''
if alpha > 0:
lam = np.random.beta(alpha, alpha)
else:
lam = 1
batch_size = x.size()[0]
if use_cuda:
index = torch.randperm(batch_size).cuda()
else:
index = torch.randperm(batch_size)
mixed_x = lam * x + (1 - lam) * x[index, :]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
def mixup_criterion(criterion, pred, y_a, y_b, lam):
return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import numpy as np
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from datasets.data_utils import ImageNetPolicy
from datasets.data_utils import SubsetDistributedSampler
def _imagenet_dataset(config):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_dir = os.path.join(config.data_dir, "train")
test_dir = os.path.join(config.data_dir, "val")
if hasattr(config, "use_aa") and config.use_aa:
train_data = dset.ImageFolder(
train_dir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
ImageNetPolicy(),
transforms.ToTensor(),
normalize,
]))
else:
train_data = dset.ImageFolder(
train_dir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.2),
transforms.ToTensor(),
normalize,
]))
test_data = dset.ImageFolder(
test_dir,
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
return train_data, test_data
def get_search_datasets(config):
train_data, test_data = _imagenet_dataset(config)
num_train = len(train_data)
indices = list(range(num_train))
split_mid = int(np.floor(0.5 * num_train))
if config.distributed:
train_sampler = SubsetDistributedSampler(train_data, indices[:split_mid])
valid_sampler = SubsetDistributedSampler(train_data, indices[split_mid:num_train])
else:
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split_mid])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split_mid:num_train])
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=config.batch_size,
sampler=train_sampler,
pin_memory=True, num_workers=config.workers)
valid_loader = torch.utils.data.DataLoader(
train_data, batch_size=config.batch_size,
sampler=valid_sampler,
pin_memory=True, num_workers=config.workers)
return [train_loader, valid_loader], [train_sampler, valid_sampler]
def get_augment_datasets(config):
train_data, test_data = _imagenet_dataset(config)
if config.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
test_sampler = torch.utils.data.distributed.DistributedSampler(test_data)
else:
train_sampler = test_sampler = None
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=config.batch_size,
sampler=train_sampler,
pin_memory=True, num_workers=config.workers)
test_loader = torch.utils.data.DataLoader(
test_data, batch_size=config.batch_size,
sampler=test_sampler,
pin_memory=True, num_workers=config.workers)
return [train_loader, test_loader], [train_sampler, test_sampler]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
- Genotype: normal/reduce gene + normal/reduce cell output connection (concat)
- gene: discrete ops information (w/o output connection)
- dag: real ops (can be mixed or discrete, but Genotype has only discrete information itself)
"""
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import ops
from ops import PRIMITIVES
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
def to_dag(C_in, gene, reduction, bn_affine=True):
""" generate discrete ops from gene """
dag = nn.ModuleList()
for edges in gene:
row = nn.ModuleList()
for op_name, s_idx in edges:
# reduction cell & from input nodes => stride = 2
stride = 2 if reduction and s_idx < 2 else 1
op = ops.OPS[op_name](C_in, stride, bn_affine)
if not isinstance(op, ops.Identity): # Identity does not use drop path
op = nn.Sequential(
op,
ops.DropPath_()
)
op.s_idx = s_idx
row.append(op)
dag.append(row)
return dag
def from_str(s):
""" generate genotype from string
e.g. "Genotype(
normal=[[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)],
[('sep_conv_3x3', 1), ('dil_conv_3x3', 2)],
[('sep_conv_3x3', 1), ('sep_conv_3x3', 2)],
[('sep_conv_3x3', 1), ('dil_conv_3x3', 4)]],
normal_concat=range(2, 6),
reduce=[[('max_pool_3x3', 0), ('max_pool_3x3', 1)],
[('max_pool_3x3', 0), ('skip_connect', 2)],
[('max_pool_3x3', 0), ('skip_connect', 2)],
[('max_pool_3x3', 0), ('skip_connect', 2)]],
reduce_concat=range(2, 6))"
"""
genotype = eval(s)
return genotype
def parse(alpha, beta, k):
"""
parse continuous alpha to discrete gene.
alpha is ParameterList:
ParameterList [
Parameter(n_edges1, n_ops),
Parameter(n_edges2, n_ops),
...
]
beta is ParameterList:
ParameterList [
Parameter(n_edges1),
Parameter(n_edges2),
...
]
gene is list:
[
[('node1_ops_1', node_idx), ..., ('node1_ops_k', node_idx)],
[('node2_ops_1', node_idx), ..., ('node2_ops_k', node_idx)],
...
]
each node has two edges (k=2) in CNN.
"""
gene = []
assert PRIMITIVES[-1] == 'none' # 'none' is implemented in mutator now
# 1) Convert the mixed op to discrete edge (single op) by choosing top-1 weight edge
# 2) Choose top-k edges per node by edge score (top-1 weight in edge)
# output the connect idx[(node_idx, connect_idx, op_idx).... () ()]
connect_idx = []
for edges, w in zip(alpha, beta):
# edges: Tensor(n_edges, n_ops)
edge_max, primitive_indices = torch.topk((w.view(-1, 1) * edges)[:, :-1], 1) # ignore 'none'
topk_edge_values, topk_edge_indices = torch.topk(edge_max.view(-1), k)
node_gene = []
node_idx = []
for edge_idx in topk_edge_indices:
prim_idx = primitive_indices[edge_idx]
prim = PRIMITIVES[prim_idx]
node_gene.append((prim, edge_idx.item()))
node_idx.append((edge_idx.item(), prim_idx.item()))
gene.append(node_gene)
connect_idx.append(node_idx)
return gene, connect_idx
def parse_gumbel(alpha, beta, k):
"""
parse continuous alpha to discrete gene.
alpha is ParameterList:
ParameterList [
Parameter(n_edges1, n_ops),
Parameter(n_edges2, n_ops),
...
]
beta is ParameterList:
ParameterList [
Parameter(n_edges1),
Parameter(n_edges2),
...
]
gene is list:
[
[('node1_ops_1', node_idx), ..., ('node1_ops_k', node_idx)],
[('node2_ops_1', node_idx), ..., ('node2_ops_k', node_idx)],
...
]
each node has two edges (k=2) in CNN.
"""
gene = []
assert PRIMITIVES[-1] == 'none' # assume last PRIMITIVE is 'none'
# 1) Convert the mixed op to discrete edge (single op) by choosing top-1 weight edge
# 2) Choose top-k edges per node by edge score (top-1 weight in edge)
# output the connect idx[(node_idx, connect_idx, op_idx).... () ()]
connect_idx = []
for edges, w in zip(alpha, beta):
# edges: Tensor(n_edges, n_ops)
discrete_a = F.gumbel_softmax(edges[:, :-1].reshape(-1), tau=1, hard=True)
for i in range(k-1):
discrete_a = discrete_a + F.gumbel_softmax(edges[:, :-1].reshape(-1), tau=1, hard=True)
discrete_a = discrete_a.reshape(-1, len(PRIMITIVES)-1)
reserved_edge = (discrete_a > 0).nonzero()
node_gene = []
node_idx = []
for i in range(reserved_edge.shape[0]):
edge_idx = reserved_edge[i][0].item()
prim_idx = reserved_edge[i][1].item()
prim = PRIMITIVES[prim_idx]
node_gene.append((prim, edge_idx))
node_idx.append((edge_idx, prim_idx))
gene.append(node_gene)
connect_idx.append(node_idx)
return gene, connect_idx
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
import ops
import numpy as np
from nni.nas.pytorch import mutables
from utils import parse_results
from aux_head import DistillHeadCIFAR, DistillHeadImagenet, AuxiliaryHeadCIFAR, AuxiliaryHeadImageNet
class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__()
self.ops = nn.ModuleList()
choice_keys = []
for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(mutables.LayerChoice([ops.OPS[k](channels, stride, False) for k in ops.PRIMITIVES],
key=choice_keys[-1]))
self.drop_path = ops.DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes)
out = [op(node) for op, node in zip(self.ops, prev_nodes)]
out = [self.drop_path(o) if o is not None else None for o in out]
return self.input_switch(out)
class Cell(nn.Module):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
super().__init__()
self.reduction = reduction
self.n_nodes = n_nodes
# If previous cell is reduction cell, current input size does not match with
# output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
if reduction_p:
self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
else:
self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)
# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(2, self.n_nodes + 2):
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
depth, channels, 2 if reduction else 0))
def forward(self, s0, s1):
# s0, s1 are the outputs of previous previous cell and previous cell, respectively.
tensors = [self.preproc0(s0), self.preproc1(s1)]
for node in self.mutable_ops:
cur_tensor = node(tensors)
tensors.append(cur_tensor)
output = torch.cat(tensors[2:], dim=1)
return output
class Model(nn.Module):
def __init__(self, dataset, n_layers, in_channels=3, channels=16, n_nodes=4, retrain=False, shared_modules=None):
super().__init__()
assert dataset in ["cifar10", "imagenet"]
self.dataset = dataset
self.input_size = 32 if dataset == "cifar" else 224
self.in_channels = in_channels
self.channels = channels
self.n_nodes = n_nodes
self.aux_size = {2 * n_layers // 3: self.input_size // 4}
if dataset == "cifar10":
self.n_classes = 10
self.aux_head_class = AuxiliaryHeadCIFAR if retrain else DistillHeadCIFAR
if not retrain:
self.aux_size = {n_layers // 3: 6, 2 * n_layers // 3: 6}
elif dataset == "imagenet":
self.n_classes = 1000
self.aux_head_class = AuxiliaryHeadImageNet if retrain else DistillHeadImagenet
if not retrain:
self.aux_size = {n_layers // 3: 6, 2 * n_layers // 3: 5}
self.n_layers = n_layers
self.aux_head = nn.ModuleDict()
self.ensemble_param = nn.Parameter(torch.rand(len(self.aux_size) + 1) / (len(self.aux_size) + 1)) \
if not retrain else None
stem_multiplier = 3 if dataset == "cifar" else 1
c_cur = stem_multiplier * self.channels
self.shared_modules = {} # do not wrap with ModuleDict
if shared_modules is not None:
self.stem = shared_modules["stem"]
else:
self.stem = nn.Sequential(
nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
nn.BatchNorm2d(c_cur)
)
self.shared_modules["stem"] = self.stem
# for the first cell, stem is used for both s0 and s1
# [!] channels_pp and channels_p is output channel size, but c_cur is input channel size.
channels_pp, channels_p, c_cur = c_cur, c_cur, channels
self.cells = nn.ModuleList()
reduction_p, reduction = False, False
aux_head_count = 0
for i in range(n_layers):
reduction_p, reduction = reduction, False
if i in [n_layers // 3, 2 * n_layers // 3]:
c_cur *= 2
reduction = True
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
self.cells.append(cell)
c_cur_out = c_cur * n_nodes
if i in self.aux_size:
if shared_modules is not None:
self.aux_head[str(i)] = shared_modules["aux" + str(aux_head_count)]
else:
self.aux_head[str(i)] = self.aux_head_class(c_cur_out, self.aux_size[i], self.n_classes)
self.shared_modules["aux" + str(aux_head_count)] = self.aux_head[str(i)]
aux_head_count += 1
channels_pp, channels_p = channels_p, c_cur_out
self.gap = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(channels_p, self.n_classes)
def forward(self, x):
s0 = s1 = self.stem(x)
outputs = []
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1)
if str(i) in self.aux_head:
outputs.append(self.aux_head[str(i)](s1))
out = self.gap(s1)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)
outputs.append(logits)
if self.ensemble_param is None:
assert len(outputs) == 2
return outputs[1], outputs[0]
else:
em_output = torch.cat([(e * o) for e, o in zip(F.softmax(self.ensemble_param, dim=0), outputs)], 0)
return logits, em_output
def drop_path_prob(self, p):
for module in self.modules():
if isinstance(module, ops.DropPath):
module.p = p
def plot_genotype(self, results, logger):
genotypes = parse_results(results, self.n_nodes)
logger.info(genotypes)
return genotypes
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
OPS = {
'avg_pool_3x3': lambda C, stride, affine: PoolWithoutBN('avg', C, 3, stride, 1, affine=affine),
'max_pool_3x3': lambda C, stride, affine: PoolWithoutBN('max', C, 3, stride, 1, affine=affine),
'skip_connect': lambda C, stride, affine: nn.Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), # 5x5
'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), # 9x9
'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
}
PRIMITIVES = [
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect', # identity
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
]
class DropPath(nn.Module):
def __init__(self, p=0.):
"""
Drop path with probability.
Parameters
----------
p : float
Probability of an path to be zeroed.
"""
super().__init__()
self.p = p
def forward(self, x):
if self.training and self.p > 0.:
keep_prob = 1. - self.p
# per data point mask
mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
return x / keep_prob * mask
return x
class PoolWithoutBN(nn.Module):
"""
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
super().__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise NotImplementedError("Pool doesn't support pooling type other than max and avg.")
def forward(self, x):
out = self.pool(x)
return out
class StdConv(nn.Module):
"""
Standard conv: ReLU - Conv - BN
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
class FacConv(nn.Module):
"""
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
class DilConv(nn.Module):
"""
(Dilated) depthwise separable conv.
ReLU - (Dilated) depthwise separable - Pointwise - BN.
If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
bias=False),
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
class SepConv(nn.Module):
"""
Depthwise separable conv.
DilConv(dilation=1) * 2.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
)
def forward(self, x):
return self.net(x)
class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise (stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
x = self.relu(x)
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import time
from argparse import ArgumentParser
import torch
import torch.nn as nn
import apex # pylint: disable=import-error
import datasets
import utils
from apex.parallel import DistributedDataParallel # pylint: disable=import-error
from config import RetrainConfig
from datasets.cifar import get_augment_datasets
from model import Model
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeterGroup
def train(logger, config, train_loader, model, optimizer, criterion, epoch, main_proc):
meters = AverageMeterGroup()
cur_lr = optimizer.param_groups[0]["lr"]
if main_proc:
logger.info("Epoch %d LR %.6f", epoch, cur_lr)
model.train()
for step, (x, y) in enumerate(train_loader):
x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
optimizer.zero_grad()
logits, aux_logits = model(x)
loss = criterion(logits, y)
if config.aux_weight > 0.:
loss += config.aux_weight * criterion(aux_logits, y)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
optimizer.step()
prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
metrics = utils.reduce_metrics(metrics, config.distributed)
meters.update(metrics)
if main_proc and (step % config.log_frequency == 0 or step + 1 == len(train_loader)):
logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, config.epochs, step + 1, len(train_loader), meters)
if main_proc:
logger.info("Train: [%d/%d] Final Prec@1 %.4f Prec@5 %.4f", epoch + 1, config.epochs, meters.prec1.avg, meters.prec5.avg)
def validate(logger, config, valid_loader, model, criterion, epoch, main_proc):
meters = AverageMeterGroup()
model.eval()
with torch.no_grad():
for step, (x, y) in enumerate(valid_loader):
x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
logits, _ = model(x)
loss = criterion(logits, y)
prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
metrics = utils.reduce_metrics(metrics, config.distributed)
meters.update(metrics)
if main_proc and (step % config.log_frequency == 0 or step + 1 == len(valid_loader)):
logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, config.epochs, step + 1, len(valid_loader), meters)
if main_proc:
logger.info("Train: [%d/%d] Final Prec@1 %.4f Prec@5 %.4f", epoch + 1, config.epochs, meters.prec1.avg, meters.prec5.avg)
return meters.prec1.avg, meters.prec5.avg
def main():
config = RetrainConfig()
main_proc = not config.distributed or config.local_rank == 0
if config.distributed:
torch.cuda.set_device(config.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method=config.dist_url,
rank=config.local_rank, world_size=config.world_size)
if main_proc:
os.makedirs(config.output_path, exist_ok=True)
if config.distributed:
torch.distributed.barrier()
logger = utils.get_logger(os.path.join(config.output_path, 'search.log'))
if main_proc:
config.print_params(logger.info)
utils.reset_seed(config.seed)
loaders, samplers = get_augment_datasets(config)
train_loader, valid_loader = loaders
train_sampler, valid_sampler = samplers
model = Model(config.dataset, config.layers, in_channels=config.input_channels, channels=config.init_channels, retrain=True).cuda()
if config.label_smooth > 0:
criterion = utils.CrossEntropyLabelSmooth(config.n_classes, config.label_smooth)
else:
criterion = nn.CrossEntropyLoss()
fixed_arc_path = os.path.join(config.output_path, config.arc_checkpoint)
with open(fixed_arc_path, "r") as f:
fixed_arc = json.load(f)
fixed_arc = utils.encode_tensor(fixed_arc, torch.device("cuda"))
genotypes = utils.parse_results(fixed_arc, n_nodes=4)
genotypes_dict = {i: genotypes for i in range(3)}
apply_fixed_architecture(model, fixed_arc_path)
param_size = utils.param_size(model, criterion, [3, 32, 32] if 'cifar' in config.dataset else [3, 224, 224])
if main_proc:
logger.info("Param size: %.6f", param_size)
logger.info("Genotype: %s", genotypes)
# change training hyper parameters according to cell type
if 'cifar' in config.dataset:
if param_size < 3.0:
config.weight_decay = 3e-4
config.drop_path_prob = 0.2
elif 3.0 < param_size < 3.5:
config.weight_decay = 3e-4
config.drop_path_prob = 0.3
else:
config.weight_decay = 5e-4
config.drop_path_prob = 0.3
if config.distributed:
apex.parallel.convert_syncbn_model(model)
model = DistributedDataParallel(model, delay_allreduce=True)
optimizer = torch.optim.SGD(model.parameters(), config.lr, momentum=config.momentum, weight_decay=config.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs, eta_min=1E-6)
best_top1 = best_top5 = 0.
for epoch in range(config.epochs):
drop_prob = config.drop_path_prob * epoch / config.epochs
if config.distributed:
model.module.drop_path_prob(drop_prob)
else:
model.drop_path_prob(drop_prob)
# training
if config.distributed:
train_sampler.set_epoch(epoch)
train(logger, config, train_loader, model, optimizer, criterion, epoch, main_proc)
# validation
top1, top5 = validate(logger, config, valid_loader, model, criterion, epoch, main_proc)
best_top1 = max(best_top1, top1)
best_top5 = max(best_top5, top5)
lr_scheduler.step()
logger.info("Final best Prec@1 = %.4f Prec@5 = %.4f", best_top1, best_top5)
if __name__ == "__main__":
main()
NGPUS=4
SGPU=0
EGPU=$[NGPUS+SGPU-1]
GPU_ID=`seq -s , $SGPU $EGPU`
CUDA_VISIBLE_DEVICES=$GPU_ID python -m torch.distributed.launch --nproc_per_node=$NGPUS retrain.py \
--dataset cifar10 --n_classes 10 --init_channels 36 --stem_multiplier 3 \
--arc_checkpoint 'epoch_31.json' \
--batch_size 128 --workers 1 --log_frequency 10 \
--world_size $NGPUS --weight_decay 5e-4 \
--distributed --dist_url 'tcp://127.0.0.1:26443' \
--lr 0.1 --warmup_epochs 0 --epochs 600 \
--cutout_length 16 --aux_weight 0.4 --drop_path_prob 0.3 \
--label_smooth 0.0 --mixup_alpha 0
NGPUS=4
SGPU=0
EGPU=$[NGPUS+SGPU-1]
GPU_ID=`seq -s , $SGPU $EGPU`
CUDA_VISIBLE_DEVICES=$GPU_ID python -m torch.distributed.launch --nproc_per_node=$NGPUS search.py \
--dataset cifar10 --n_classes 10 --init_channels 16 --stem_multiplier 3 \
--batch_size 64 --workers 1 --log_frequency 10 \
--distributed --world_size $NGPUS --dist_url 'tcp://127.0.0.1:23343' \
--regular_ratio 0.2 --regular_coeff 5 \
--loss_alpha 1 --loss_T 2 \
--w_lr 0.2 --alpha_lr 3e-4 --nasnet_lr 0.2 \
--w_weight_decay 0. --alpha_weight_decay 0. \
--share_module --interactive_type kl \
--warmup_epochs 2 --epochs 32
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import random
import time
import numpy as np
import torch
import torch.nn as nn
import utils
from config import SearchConfig
from datasets.cifar import get_search_datasets
from model import Model
from nni.nas.pytorch.cdarts import CdartsTrainer
if __name__ == "__main__":
config = SearchConfig()
main_proc = not config.distributed or config.local_rank == 0
if config.distributed:
torch.cuda.set_device(config.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method=config.dist_url,
rank=config.local_rank, world_size=config.world_size)
if main_proc:
os.makedirs(config.output_path, exist_ok=True)
if config.distributed:
torch.distributed.barrier()
logger = utils.get_logger(os.path.join(config.output_path, 'search.log'))
if main_proc:
config.print_params(logger.info)
utils.reset_seed(config.seed)
loaders, samplers = get_search_datasets(config)
model_small = Model(config.dataset, 8).cuda()
if config.share_module:
model_large = Model(config.dataset, 20, shared_modules=model_small.shared_modules).cuda()
else:
model_large = Model(config.dataset, 20).cuda()
criterion = nn.CrossEntropyLoss()
trainer = CdartsTrainer(model_small, model_large, criterion, loaders, samplers, logger,
config.regular_coeff, config.regular_ratio, config.warmup_epochs, config.fix_head,
config.epochs, config.steps_per_epoch, config.loss_alpha, config.loss_T, config.distributed,
config.log_frequency, config.grad_clip, config.interactive_type, config.output_path,
config.w_lr, config.w_momentum, config.w_weight_decay, config.alpha_lr, config.alpha_weight_decay,
config.nasnet_lr, config.local_rank, config.share_module)
trainer.train()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import random
from collections import namedtuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from genotypes import Genotype
from ops import PRIMITIVES
from nni.nas.pytorch.cdarts.utils import *
def get_logger(file_path):
""" Make python logger """
logger = logging.getLogger('cdarts')
log_format = '%(asctime)s | %(message)s'
formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')
file_handler = logging.FileHandler(file_path)
file_handler.setFormatter(formatter)
# stream_handler = logging.StreamHandler()
# stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# logger.addHandler(stream_handler)
logger.setLevel(logging.INFO)
return logger
class CyclicIterator:
def __init__(self, loader, sampler, distributed):
self.loader = loader
self.sampler = sampler
self.epoch = 0
self.distributed = distributed
self._next_epoch()
def _next_epoch(self):
if self.distributed:
self.sampler.set_epoch(self.epoch)
self.iterator = iter(self.loader)
self.epoch += 1
def __len__(self):
return len(self.loader)
def __iter__(self):
return self
def __next__(self):
try:
return next(self.iterator)
except StopIteration:
self._next_epoch()
return next(self.iterator)
class CrossEntropyLabelSmooth(nn.Module):
def __init__(self, num_classes, epsilon):
super(CrossEntropyLabelSmooth, self).__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, inputs, targets):
log_probs = self.logsoftmax(inputs)
targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
loss = (-targets * log_probs).mean(0).sum()
return loss
def parse_results(results, n_nodes):
concat = range(2, 2 + n_nodes)
normal_gene = []
reduction_gene = []
for i in range(n_nodes):
normal_node = []
reduction_node = []
for j in range(2 + i):
normal_key = 'normal_n{}_p{}'.format(i + 2, j)
reduction_key = 'reduce_n{}_p{}'.format(i + 2, j)
normal_op = results[normal_key].cpu().numpy()
reduction_op = results[reduction_key].cpu().numpy()
if sum(normal_op == 1):
normal_index = np.argmax(normal_op)
normal_node.append((PRIMITIVES[normal_index], j))
if sum(reduction_op == 1):
reduction_index = np.argmax(reduction_op)
reduction_node.append((PRIMITIVES[reduction_index], j))
normal_gene.append(normal_node)
reduction_gene.append(reduction_node)
genotypes = Genotype(normal=normal_gene, normal_concat=concat,
reduce=reduction_gene, reduce_concat=concat)
return genotypes
def param_size(model, loss_fn, input_size):
"""
Compute parameter size in MB
"""
x = torch.rand([2] + input_size).cuda()
y, _ = model(x)
target = torch.randint(model.n_classes, size=[2]).cuda()
loss = loss_fn(y, target)
loss.backward()
n_params = sum(np.prod(v.size()) for k, v in model.named_parameters() if not k.startswith('aux_head') and v.grad is not None)
return n_params / 1e6
def encode_tensor(data, device):
if isinstance(data, list):
if all(map(lambda o: isinstance(o, bool), data)):
return torch.tensor(data, dtype=torch.bool, device=device) # pylint: disable=not-callable
else:
return torch.tensor(data, dtype=torch.float, device=device) # pylint: disable=not-callable
if isinstance(data, dict):
return {k: encode_tensor(v, device) for k, v in data.items()}
return data
def reset_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator
from .trainer import CdartsTrainer
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from apex.parallel import DistributedDataParallel # pylint: disable=import-error
from nni.nas.pytorch.darts import DartsMutator # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutables import LayerChoice # pylint: disable=wrong-import-order
from nni.nas.pytorch.mutator import Mutator # pylint: disable=wrong-import-order
class RegularizedDartsMutator(DartsMutator):
"""
This is :class:`~nni.nas.pytorch.darts.DartsMutator` basically, with two differences.
1. Choices can be cut (bypassed). This is done by ``cut_choices``. Cutted choices will not be used in
forward pass and thus consumes no memory.
2. Regularization on choices, to prevent the mutator from overfitting on some choices.
"""
def reset(self):
"""
Warnings
--------
Renamed :func:`~reset_with_loss` to return regularization loss on reset.
"""
raise ValueError("You should probably call `reset_with_loss`.")
def cut_choices(self, cut_num=2):
"""
Cut the choices with the smallest weights.
``cut_num`` should be the accumulative number of cutting, e.g., if first time cutting
is 2, the second time should be 4 to cut another two.
Parameters
----------
cut_num : int
Number of choices to cut, so far.
Warnings
--------
Though the parameters are set to :math:`-\infty` to be bypassed, they will still receive gradient of 0,
which introduced ``nan`` problem when calling ``optimizer.step()``. To solve this issue, a simple way is to
reset nan to :math:`-\infty` each time after the parameters are updated.
"""
# `cut_choices` is implemented but not used in current implementation of CdartsTrainer
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
_, idx = torch.topk(-self.choices[mutable.key], cut_num)
with torch.no_grad():
for i in idx:
self.choices[mutable.key][i] = -float("inf")
def reset_with_loss(self):
"""
Resample and return loss. If loss is 0, to avoid device issue, it will return ``None``.
Currently loss penalty are proportional to the L1-norm of parameters corresponding
to modules if their type name contains certain substrings. These substrings include: ``poolwithoutbn``,
``identity``, ``dilconv``.
"""
self._cache, reg_loss = self.sample_search()
return reg_loss
def sample_search(self):
result = super().sample_search()
loss = []
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
def need_reg(choice):
return any(t in str(type(choice)).lower() for t in ["poolwithoutbn", "identity", "dilconv"])
for i, choice in enumerate(mutable.choices):
if need_reg(choice):
norm = torch.abs(self.choices[mutable.key][i])
if norm < 1E10:
loss.append(norm)
if not loss:
return result, None
return result, sum(loss)
def export(self, logger=None):
"""
Export an architecture with logger. Genotype will be printed with logger.
Returns
-------
dict
A mapping from mutable keys to decisions.
"""
result = self.sample_final()
if hasattr(self.model, "plot_genotype") and logger is not None:
genotypes = self.model.plot_genotype(result, logger)
return result, genotypes
class RegularizedMutatorParallel(DistributedDataParallel):
"""
Parallelize :class:`~RegularizedDartsMutator`.
This makes :func:`~RegularizedDartsMutator.reset_with_loss` method parallelized,
also allowing :func:`~RegularizedDartsMutator.cut_choices` and :func:`~RegularizedDartsMutator.export`
to be easily accessible.
"""
def reset_with_loss(self):
"""
Parallelized :func:`~RegularizedDartsMutator.reset_with_loss`.
"""
result = self.module.reset_with_loss()
self.callback_queued = False
return result
def cut_choices(self, *args, **kwargs):
"""
Parallelized :func:`~RegularizedDartsMutator.cut_choices`.
"""
self.module.cut_choices(*args, **kwargs)
def export(self, logger):
"""
Parallelized :func:`~RegularizedDartsMutator.export`.
"""
return self.module.export(logger)
class DartsDiscreteMutator(Mutator):
"""
A mutator that applies the final sampling result of a parent mutator on another model to train.
"""
def __init__(self, model, parent_mutator):
"""
Initialization.
Parameters
----------
model : nn.Module
The model to apply the mutator.
parent_mutator : Mutator
The mutator that provides ``sample_final`` method, that will be called to get the architecture.
"""
super().__init__(model)
self.__dict__["parent_mutator"] = parent_mutator # avoid parameters to be included
def sample_search(self):
return self.parent_mutator.sample_final()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment