"""ENAS macro-space search on CIFAR-10, built on NNI's NAS search space zoo."""

import logging
from argparse import ArgumentParser

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import CIFAR10

from nni.algorithms.nas.pytorch import enas
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
                                       LRSchedulerCallback)
from nni.nas.pytorch.search_space_zoo import ENASMacroGeneralModel

from utils import accuracy, reward_accuracy

logger = logging.getLogger('nni')


def get_dataset(cls):
    """Return (train, valid) CIFAR-10 datasets with standard ENAS augmentation."""
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]
    train_transform = transforms.Compose(transf + normalize)
    valid_transform = transforms.Compose(normalize)

    if cls == "cifar10":
        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    else:
        raise NotImplementedError
    return dataset_train, dataset_valid


class FactorizedReduce(nn.Module):
    """Halve spatial resolution with two offset stride-2 1x1 convs whose outputs
    are concatenated channel-wise, so no pixel row/column is simply discarded."""

    def __init__(self, C_in, C_out, affine=False):
        super().__init__()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        # conv2 sees the input shifted by one pixel, so the two stride-2 paths
        # together cover all spatial positions
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out


if __name__ == "__main__":
    parser = ArgumentParser("enas")
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--epochs", default=None, type=int,
                        help="Number of epochs (default: macro 310, micro 150)")
    parser.add_argument("--visualization", default=False, action="store_true")
    args = parser.parse_args()

    dataset_train, dataset_valid = get_dataset("cifar10")

    # Macro search space from the zoo; leaving mutator as None lets EnasTrainer
    # construct its default ENAS controller.
    model = ENASMacroGeneralModel()
    num_epochs = args.epochs or 310
    mutator = None

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)

    trainer = enas.EnasTrainer(model,
                               loss=criterion,
                               metrics=accuracy,
                               reward_function=reward_accuracy,
                               optimizer=optimizer,
                               callbacks=[LRSchedulerCallback(lr_scheduler),
                                          ArchitectureCheckpoint("./checkpoints")],
                               batch_size=args.batch_size,
                               num_epochs=num_epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
                               log_frequency=args.log_frequency,
                               mutator=mutator)
    if args.visualization:
        trainer.enable_visualization()
    trainer.train()
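
# ---------------------------------------------------------------------------
# Note: the `utils` module imported above ships alongside this example and is
# not defined in this file. Below is a minimal sketch of compatible helpers,
# assuming EnasTrainer invokes metrics(logits, target) and
# reward_function(logits, target) with raw logits and integer class labels;
# these signatures mirror the imports above and are an assumption, not a
# verified API.
#
#     # utils.py
#     import torch
#
#     def accuracy(output, target, topk=(1,)):
#         # top-k accuracy, returned as a dict so the trainer can log each k
#         maxk = max(topk)
#         batch_size = target.size(0)
#         _, pred = output.topk(maxk, 1, True, True)
#         pred = pred.t()
#         correct = pred.eq(target.view(1, -1).expand_as(pred))
#         res = {}
#         for k in topk:
#             correct_k = correct[:k].reshape(-1).float().sum(0)
#             res["acc{}".format(k)] = (correct_k / batch_size).item()
#         return res
#
#     def reward_accuracy(output, target, topk=(1,)):
#         # scalar top-1 accuracy, used as the RL reward for the ENAS controller
#         batch_size = target.size(0)
#         _, predicted = torch.max(output.data, 1)
#         return (predicted == target).sum().item() / batch_size
#
# Example invocation (the file name is illustrative):
#
#     python enas_macro.py --batch-size 128 --epochs 310 --visualization
# ---------------------------------------------------------------------------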