Commit 6c1fe5c8 authored by Yuge Zhang, committed by QuanluZhang

Fix minor issues in DARTS and add a simple CIFAR10 example (with random mutator) (#1776)

parent 9aed500d
@@ -51,7 +51,7 @@ class Node(nn.Module):
                 ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
             ],
             key=choice_keys[-1]))
-        self.drop_path = ops.DropPath_()
+        self.drop_path = ops.DropPath()
         self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))

     def forward(self, prev_nodes):
@@ -153,5 +153,5 @@ class CNN(nn.Module):
     def drop_path_prob(self, p):
         for module in self.modules():
-            if isinstance(module, ops.DropPath_):
+            if isinstance(module, ops.DropPath):
                 module.p = p
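ops.DropPath keeps p as a plain attribute, so this setter can retarget every instance in the network each epoch. That supports the usual schedule of ramping the drop path probability over retraining; a self-contained sketch (the stand-in module and the linear schedule are assumptions, not part of this commit):

import torch.nn as nn

class DropPath(nn.Module):  # minimal stand-in for ops.DropPath
    def __init__(self, p=0.):
        super().__init__()
        self.p = p

model = nn.Sequential(nn.Conv2d(3, 3, 3), DropPath())

def drop_path_prob(model, p):
    # Same walk as CNN.drop_path_prob: update every DropPath in the module tree.
    for module in model.modules():
        if isinstance(module, DropPath):
            module.p = p

# e.g. ramp the probability linearly over 10 epochs
for epoch in range(10):
    drop_path_prob(model, 0.2 * epoch / 10)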
@@ -2,10 +2,10 @@ import torch
 import torch.nn as nn


-class DropPath_(nn.Module):
+class DropPath(nn.Module):
     def __init__(self, p=0.):
         """
-        DropPath is inplace module.
+        Drop path with probability.

         Parameters
         ----------
@@ -15,15 +15,12 @@ class DropPath_(nn.Module):
         super().__init__()
         self.p = p

-    def extra_repr(self):
-        return 'p={}, inplace'.format(self.p)
-
     def forward(self, x):
         if self.training and self.p > 0.:
             keep_prob = 1. - self.p
             # per data point mask
             mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
-            x.div_(keep_prob).mul_(mask)
+            return x / keep_prob * mask
         return x
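The non-inplace rewrite keeps drop path's inverted-dropout contract: each sample in the batch is either zeroed entirely or scaled by 1/keep_prob, so the expected activation is unchanged during training. A quick self-contained check of that property, using only the arithmetic from the new forward:

import torch

p = 0.2
keep_prob = 1. - p
x = torch.ones(100000, 1, 1, 1)
# Per-sample Bernoulli mask, exactly as in DropPath.forward
mask = torch.zeros(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
out = x / keep_prob * mask
print(out.mean().item())  # close to 1.0: the scaling preserves the expectation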
@@ -33,7 +33,7 @@ def train(config, train_loader, model, optimizer, criterion, epoch):
     losses = AverageMeter("losses")

     cur_step = epoch * len(train_loader)
-    cur_lr = optimizer.param_groups[0]['lr']
+    cur_lr = optimizer.param_groups[0]["lr"]
     logger.info("Epoch %d LR %.6f", epoch, cur_lr)
     writer.add_scalar("lr", cur_lr, global_step=cur_step)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.darts import DartsTrainer


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = LayerChoice([nn.Conv2d(3, 6, 3, padding=1), nn.Conv2d(3, 6, 5, padding=2)])
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = LayerChoice([nn.Conv2d(6, 16, 3, padding=1), nn.Conv2d(6, 16, 5, padding=2)])
        self.conv3 = nn.Conv2d(16, 16, 1)
        self.skipconnect = InputChoice(n_candidates=1)
        self.bn = nn.BatchNorm2d(16)
        self.gap = nn.AdaptiveAvgPool2d(4)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        bs = x.size(0)
        x = self.pool(F.relu(self.conv1(x)))
        x0 = F.relu(self.conv2(x))
        x1 = F.relu(self.conv3(x0))
        x0 = self.skipconnect([x0])
        if x0 is not None:
            x1 += x0
        x = self.pool(self.bn(x1))
        x = self.gap(x).view(bs, -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def accuracy(output, target):
    batch_size = target.size(0)
    _, predicted = torch.max(output.data, 1)
    return {"acc1": (predicted == target).sum().item() / batch_size}


if __name__ == "__main__":
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset_train = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    dataset_valid = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    trainer = DartsTrainer(net,
                           loss=criterion,
                           metrics=accuracy,
                           optimizer=optimizer,
                           num_epochs=2,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=64,
                           log_frequency=10)
    trainer.train()
    trainer.export("checkpoint.json")
+import logging
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -5,6 +7,8 @@ import torch.nn.functional as F
 from nni.nas.pytorch.mutator import Mutator
 from nni.nas.pytorch.mutables import LayerChoice, InputChoice

+_logger = logging.getLogger(__name__)
+

 class DartsMutator(Mutator):
     def __init__(self, model):
@@ -36,9 +40,14 @@ class DartsMutator(Mutator):
                 edges_max[mutable.key] = max_val
                 result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
         for mutable in self.mutables:
-            if isinstance(mutable, InputChoice):
-                weights = torch.tensor([edges_max.get(src_key, 0.) for src_key in mutable.choose_from])  # pylint: disable=not-callable
-                _, topk_edge_indices = torch.topk(weights, mutable.n_chosen or mutable.n_candidates)
+            if isinstance(mutable, InputChoice) and mutable.n_chosen is not None:
+                weights = []
+                for src_key in mutable.choose_from:
+                    if src_key not in edges_max:
+                        _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
+                    weights.append(edges_max.get(src_key, 0.))
+                weights = torch.tensor(weights)  # pylint: disable=not-callable
+                _, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
                 selected_multihot = []
                 for i, src_key in enumerate(mutable.choose_from):
                     if i not in topk_edge_indices and src_key in result:
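With the new guard, top-k selection only runs when n_chosen is fixed; unrestricted InputChoices fall through untouched. The selection step itself is just torch.topk over per-edge scores. A tiny illustration with hypothetical weights for a 4-candidate choice with n_chosen=2:

import torch

# Hypothetical edge scores for a 4-candidate InputChoice with n_chosen=2.
weights = torch.tensor([0.10, 0.45, 0.30, 0.15])
_, topk_edge_indices = torch.topk(weights, 2)
mask = torch.zeros(4, dtype=torch.bool)
mask[topk_edge_indices] = True
print(mask)  # tensor([False,  True,  True, False])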
@@ -62,7 +62,8 @@ class Mutable(nn.Module):
     def _check_built(self):
         if not hasattr(self, "mutator"):
             raise ValueError(
-                "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
+                "Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
+                "Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
                 "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

     def __repr__(self):
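The reworded error covers the two usual causes: never applying a mutator to the model, or building a mutable inside forward, where no mutator walk can ever find it. A hedged sketch of the broken and fixed patterns (class names are illustrative):

import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice

class Wrong(nn.Module):
    def forward(self, x):
        # Created on the fly: the mutator never sees it, so _check_built fails.
        return LayerChoice([nn.Conv2d(3, 3, 3, padding=1), nn.Identity()])(x)

class Right(nn.Module):
    def __init__(self):
        super().__init__()
        # Declared in __init__, so the trainer/mutator can locate it.
        self.op = LayerChoice([nn.Conv2d(3, 3, 3, padding=1), nn.Identity()])

    def forward(self, x):
        return self.op(x)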
@@ -179,7 +180,8 @@ class InputChoice(Mutable):
         optional_input_list = optional_inputs
         if isinstance(optional_inputs, dict):
             optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
-        assert isinstance(optional_input_list, list), "Optional input list must be a list"
+        assert isinstance(optional_input_list, list), \
+            "Optional input list must be a list, not a {}.".format(type(optional_input_list))
         assert len(optional_inputs) == self.n_candidates, \
             "Length of the input list must be equal to number of candidates."
         out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
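When optional_inputs is a dict, it is reordered into candidate order via choose_from before the assertions fire, so callers can pass inputs keyed by source rather than by position. A small illustration with hypothetical source keys:

import torch

choose_from = ["conv3", "conv5"]  # hypothetical source keys
optional_inputs = {"conv5": torch.ones(1), "conv3": torch.zeros(1)}
# Reordered into candidate order before the type/length checks run.
optional_input_list = [optional_inputs[tag] for tag in choose_from]
print(optional_input_list)  # [tensor([0.]), tensor([1.])]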
@@ -76,7 +76,8 @@ class Mutator(BaseMutator):
             return op(*inputs)
         mask = self._get_decision(mutable)
-        assert len(mask) == len(mutable.choices)
+        assert len(mask) == len(mutable.choices), \
+            "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable.choices))
         out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
         return self._tensor_reduction(mutable.reduction, out), mask
@@ -98,7 +99,8 @@
         tuple of torch.Tensor and torch.Tensor
         """
         mask = self._get_decision(mutable)
-        assert len(mask) == mutable.n_candidates
+        assert len(mask) == mutable.n_candidates, \
+            "Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
         out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
         return self._tensor_reduction(mutable.reduction, out), mask
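Both strengthened assertions check the same invariant: a sampled mask has exactly one boolean per candidate. _select_with_mask is internal, but its contract is easy to mimic; a sketch of how such a mask picks tensors and how a sum reduction would then combine them:

import torch

candidates = [torch.ones(2) * i for i in range(4)]
mask = torch.tensor([False, True, False, True])  # one entry per candidate
# Keep only the candidates whose mask entry is True, then reduce by sum.
selected = [t for t, m in zip(candidates, mask) if m]
print(torch.stack(selected).sum(0))  # tensor([4., 4.])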
from .mutator import RandomMutator
\ No newline at end of file
import torch
import torch.nn.functional as F

from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice


class RandomMutator(Mutator):
    def sample_search(self):
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                gen_index = torch.randint(high=mutable.length, size=(1, ))
                result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
            elif isinstance(mutable, InputChoice):
                if mutable.n_chosen is None:
                    result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
                else:
                    perm = torch.randperm(mutable.n_candidates)
                    mask = [i in perm[:mutable.n_chosen] for i in range(mutable.n_candidates)]
                    result[mutable.key] = torch.tensor(mask, dtype=torch.bool)  # pylint: disable=not-callable
        return result

    def sample_final(self):
        return self.sample_search()
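Per the commit title, this mutator can drive the new CIFAR10 example without any trainer. A minimal sketch, assuming the package is importable as nni.nas.pytorch.random (per the new __init__.py above) and that Mutator.reset() samples and applies a fresh decision:

import torch
from nni.nas.pytorch.random import RandomMutator  # package path assumed

net = Net()  # the example network defined earlier in this commit
mutator = RandomMutator(net)
mutator.reset()  # sample a random architecture for the next forward pass
x = torch.randn(1, 3, 32, 32)
print(net(x).shape)  # torch.Size([1, 10])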