Unverified commit f5dbcd81, authored by Yuge Zhang, committed by GitHub

Fix layer choice on IT and deprecate "choices" and "length" (#2386)

parent 140eb682
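This commit replaces the `length` and `choices` attributes of `LayerChoice` with Python's standard container protocol, keeping both as deprecated properties. A minimal before/after sketch (the candidate modules are illustrative):

    import torch.nn as nn
    from nni.nas.pytorch.mutables import LayerChoice

    layer_choice = LayerChoice([nn.Conv2d(3, 6, 3), nn.Conv2d(3, 6, 5)])

    # Deprecated after this commit (still works, but emits a DeprecationWarning):
    n = layer_choice.length
    ops = layer_choice.choices

    # Recommended replacements:
    n = len(layer_choice)        # number of candidate ops
    ops = list(layer_choice)     # candidate modules, in order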
@@ -92,8 +92,8 @@ class ClassicMutator(Mutator):
             The list for corresponding search space.
         """
         # doesn't support multihot for layer choice yet
-        onehot_list = [False] * mutable.length
-        assert 0 <= idx < mutable.length and search_space_item[idx] == value, \
+        onehot_list = [False] * len(mutable)
+        assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
             "Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
         onehot_list[idx] = True
         return torch.tensor(onehot_list, dtype=torch.bool)  # pylint: disable=not-callable
......
@@ -61,7 +61,7 @@ class DartsMutator(Mutator):
             if isinstance(mutable, LayerChoice):
                 max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
                 edges_max[mutable.key] = max_val
-                result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
+                result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool()
         for mutable in self.mutables:
             if isinstance(mutable, InputChoice):
                 if mutable.n_chosen is not None:
......
@@ -86,15 +86,15 @@ class EnasMutator(Mutator):
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
                 if self.max_layer_choice == 0:
-                    self.max_layer_choice = mutable.length
-                assert self.max_layer_choice == mutable.length, \
+                    self.max_layer_choice = len(mutable)
+                assert self.max_layer_choice == len(mutable), \
                     "ENAS mutator requires that all layer choices have the same number of candidates."
                 # We are judging by keys and module types to add biases to layer choices. Needs refactor.
                 if "reduce" in mutable.key:
                     def is_conv(choice):
                         return "conv" in str(type(choice)).lower()
                     bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias  # pylint: disable=not-callable
-                                         for choice in mutable.choices])
+                                         for choice in mutable])
                     self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)
 
         self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
......
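Since `LayerChoice` now iterates over its candidate modules directly, per-candidate tensors like the branch bias above can be built without touching `.choices`. A standalone sketch of the same pattern (the candidate set and bias value are illustrative):

    import torch
    import torch.nn as nn
    from nni.nas.pytorch.mutables import LayerChoice

    def is_conv(choice):
        return "conv" in str(type(choice)).lower()

    mutable = LayerChoice([nn.Conv2d(3, 6, 3), nn.MaxPool2d(3)])
    branch_bias = 0.25  # illustrative value
    bias = torch.tensor([branch_bias if is_conv(choice) else -branch_bias
                         for choice in mutable])
    # bias: tensor([ 0.2500, -0.2500])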
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.
 
 import logging
+import warnings
 from collections import OrderedDict
 
 import torch.nn as nn
@@ -140,9 +141,12 @@ class LayerChoice(Mutable):
     Attributes
     ----------
     length : int
-        Number of ops to choose from.
-    names: list of str
+        Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended.
+    names : list of str
         Names of candidates.
+    choices : list of Module
+        Deprecated. A list of all candidate modules in the layer choice module.
+        ``list(layer_choice)`` is recommended, which will serve the same purpose.
 
     Notes
     -----
@@ -156,30 +160,65 @@ class LayerChoice(Mutable):
             ("conv7x7", nn.Conv2d(7, 16, 128))
         ]))
 
+    Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or
+    ``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
     """
 
     def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
         super().__init__(key=key)
-        self.length = len(op_candidates)
-        self.choices = []
         self.names = []
         if isinstance(op_candidates, OrderedDict):
             for name, module in op_candidates.items():
                 assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
                     "Please don't use a reserved name '{}' for your module.".format(name)
                 self.add_module(name, module)
-                self.choices.append(module)
                 self.names.append(name)
         elif isinstance(op_candidates, list):
             for i, module in enumerate(op_candidates):
                 self.add_module(str(i), module)
-                self.choices.append(module)
                 self.names.append(str(i))
         else:
             raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
         self.reduction = reduction
         self.return_mask = return_mask
 
+    def __getitem__(self, idx):
+        if isinstance(idx, str):
+            return self._modules[idx]
+        return list(self)[idx]
+
+    def __setitem__(self, idx, module):
+        key = idx if isinstance(idx, str) else self.names[idx]
+        return setattr(self, key, module)
+
+    def __delitem__(self, idx):
+        if isinstance(idx, slice):
+            for key in self.names[idx]:
+                delattr(self, key)
+        else:
+            if isinstance(idx, str):
+                key, idx = idx, self.names.index(idx)
+            else:
+                key = self.names[idx]
+            delattr(self, key)
+        del self.names[idx]
+
+    @property
+    def length(self):
+        warnings.warn("layer_choice.length is deprecated. Use `len(layer_choice)` instead.", DeprecationWarning)
+        return len(self)
+
+    def __len__(self):
+        return len(self.names)
+
+    def __iter__(self):
+        return map(lambda name: self._modules[name], self.names)
+
+    @property
+    def choices(self):
+        warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.", DeprecationWarning)
+        return list(self)
+
     def forward(self, *args, **kwargs):
         """
         Returns
......
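Together, these methods make `LayerChoice` an ordered, mutable container of modules, indexable by position or by name. A usage sketch of the protocol defined above:

    from collections import OrderedDict
    import torch.nn as nn
    from nni.nas.pytorch.mutables import LayerChoice

    op_choice = LayerChoice(OrderedDict([
        ("conv3x3", nn.Conv2d(3, 16, 3)),
        ("conv5x5", nn.Conv2d(3, 16, 5)),
        ("conv7x7", nn.Conv2d(3, 16, 7)),
    ]))

    assert len(op_choice) == 3                      # __len__
    first = op_choice[0]                            # __getitem__ by position
    same = op_choice["conv3x3"]                     # __getitem__ by name
    op_choice[1] = nn.Conv2d(3, 16, 5, padding=2)   # __setitem__ replaces in place
    del op_choice["conv7x7"]                        # __delitem__ by name (index/slice also work)
    assert op_choice.names == ["conv3x3", "conv5x5"]
    for op in op_choice:                            # __iter__ yields modules in order
        print(type(op).__name__)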
@@ -150,16 +150,16 @@ class Mutator(BaseMutator):
         """
         if self._connect_all:
             return self._all_connect_tensor_reduction(mutable.reduction,
-                                                      [op(*args, **kwargs) for op in mutable.choices]), \
-                torch.ones(mutable.length)
+                                                      [op(*args, **kwargs) for op in mutable]), \
+                torch.ones(len(mutable))
 
         def _map_fn(op, args, kwargs):
             return op(*args, **kwargs)
 
         mask = self._get_decision(mutable)
-        assert len(mask) == len(mutable.choices), \
-            "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable.choices))
-        out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable.choices], mask)
+        assert len(mask) == len(mutable), \
+            "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable))
+        out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
         return self._tensor_reduction(mutable.reduction, out), mask
 
     def on_forward_input_choice(self, mutable, tensor_list):
......
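On forward, the decision mask is a boolean vector over the candidates: only the ops flagged `True` are executed, and their outputs are reduced (here with `"sum"`). A self-contained sketch of that select-and-reduce step, not the mutator's exact helpers:

    import torch
    import torch.nn as nn

    ops = [nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(3, 8, 5, padding=2)]
    mask = torch.tensor([False, True])   # decision: run only the second op
    x = torch.randn(1, 3, 32, 32)

    outs = [op(x) for op, m in zip(ops, mask) if m]   # run selected candidates only
    out = sum(outs)                                   # "sum" reduction over selected outputs
    assert out.shape == (1, 8, 32, 32)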
@@ -32,7 +32,7 @@ class PdartsMutator(DartsMutator):
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
-                switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
+                switches = self.switches.get(mutable.key, [True for j in range(len(mutable))])
                 choices = self.choices[mutable.key]
 
                 operations_count = np.sum(switches)
@@ -48,12 +48,12 @@ class PdartsMutator(DartsMutator):
             if isinstance(module, LayerChoice):
                 switches = self.switches.get(module.key)
                 choices = self.choices[module.key]
-                if len(module.choices) > len(choices):
+                if len(module) > len(choices):
                     # from last to first, so that removing one doesn't affect earlier indices
                     for index in range(len(switches)-1, -1, -1):
                         if switches[index] == False:
-                            del(module.choices[index])
-                            module.length -= 1
+                            del module[index]
+                    assert len(module) <= len(choices), "Failed to remove dropped choices."
 
     def sample_final(self):
         results = super().sample_final()
......
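P-DARTS can now drop a pruned candidate with `del module[index]` instead of mutating `.choices` and `.length` by hand; iterating from the last index down keeps the remaining indices valid, as the comment above notes. A minimal sketch with an illustrative switch pattern:

    import torch.nn as nn
    from nni.nas.pytorch.mutables import LayerChoice

    module = LayerChoice([nn.Conv2d(3, 6, 3), nn.Conv2d(3, 6, 5), nn.Conv2d(3, 6, 7)])
    switches = [True, False, True]  # illustrative: drop the middle candidate

    for index in range(len(switches) - 1, -1, -1):  # last to first
        if not switches[index]:
            del module[index]

    assert len(module) == 2 and module.names == ["0", "2"]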
@@ -53,15 +53,15 @@ class MixedOp(nn.Module):
             A LayerChoice in user model
         """
         super(MixedOp, self).__init__()
-        self.ap_path_alpha = nn.Parameter(torch.Tensor(mutable.length))
-        self.ap_path_wb = nn.Parameter(torch.Tensor(mutable.length))
+        self.ap_path_alpha = nn.Parameter(torch.Tensor(len(mutable)))
+        self.ap_path_wb = nn.Parameter(torch.Tensor(len(mutable)))
         self.ap_path_alpha.requires_grad = False
         self.ap_path_wb.requires_grad = False
         self.active_index = [0]
         self.inactive_index = None
         self.log_prob = None
         self.current_prob_over_ops = None
-        self.n_choices = mutable.length
+        self.n_choices = len(mutable)
 
     def get_ap_path_alpha(self):
         return self.ap_path_alpha
@@ -120,8 +120,8 @@ class MixedOp(nn.Module):
                     return binary_grads
                 return backward
             output = ArchGradientFunction.apply(
-                x, self.ap_path_wb, run_function(mutable.key, mutable.choices, self.active_index[0]),
-                backward_function(mutable.key, mutable.choices, self.active_index[0], self.ap_path_wb))
+                x, self.ap_path_wb, run_function(mutable.key, list(mutable), self.active_index[0]),
+                backward_function(mutable.key, list(mutable), self.active_index[0], self.ap_path_wb))
         else:
             output = self.active_op(mutable)(x)
         return output
@@ -164,7 +164,7 @@ class MixedOp(nn.Module):
         PyTorch module
             the chosen operation
         """
-        return mutable.choices[self.active_index[0]]
+        return mutable[self.active_index[0]]
 
     @property
     def active_op_index(self):
@@ -222,12 +222,12 @@ class MixedOp(nn.Module):
             sample = torch.multinomial(probs, 1)[0].item()
             self.active_index = [sample]
             self.inactive_index = [_i for _i in range(0, sample)] + \
-                [_i for _i in range(sample + 1, len(mutable.choices))]
+                [_i for _i in range(sample + 1, len(mutable))]
             self.log_prob = torch.log(probs[sample])
             self.current_prob_over_ops = probs
             self.ap_path_wb.data[sample] = 1.0
 
         # avoid over-regularization
-        for choice in mutable.choices:
+        for choice in mutable:
             for _, param in choice.named_parameters():
                 param.grad = None
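For context, this binarization step samples one active path from the softmax over the architecture parameters and marks every other candidate inactive. The sampling logic in isolation (the alpha values are illustrative):

    import torch
    import torch.nn.functional as F

    ap_path_alpha = torch.tensor([0.1, 0.7, 0.2])    # illustrative architecture parameters
    probs = F.softmax(ap_path_alpha, dim=-1)

    sample = torch.multinomial(probs, 1)[0].item()   # draw one path
    active_index = [sample]
    inactive_index = [i for i in range(0, sample)] + \
        [i for i in range(sample + 1, len(probs))]   # len(mutable) in the mutator
    log_prob = torch.log(probs[sample])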
@@ -430,8 +430,8 @@ class ProxylessNasMutator(BaseMutator):
             involved_index = mixed_op.active_index
             for i in range(mixed_op.n_choices):
                 if i not in involved_index:
-                    unused[i] = mutable.choices[i]
-                    mutable.choices[i] = None
+                    unused[i] = mutable[i]
+                    mutable[i] = None
             self._unused_modules.append(unused)
 
     def unused_modules_back(self):
def unused_modules_back(self):
......@@ -442,7 +442,7 @@ class ProxylessNasMutator(BaseMutator):
return
for m, unused in zip(self.mutable_list, self._unused_modules):
for i in unused:
m.choices[i] = unused[i]
m[i] = unused[i]
self._unused_modules = None
def arch_requires_grad(self):
......@@ -474,5 +474,5 @@ class ProxylessNasMutator(BaseMutator):
assert isinstance(mutable, LayerChoice)
index, _ = mutable.registered_module.chosen_index
# pylint: disable=not-callable
result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=mutable.length).view(-1).bool()
result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=len(mutable)).view(-1).bool()
return result
@@ -18,8 +18,8 @@ class RandomMutator(Mutator):
         result = dict()
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
-                gen_index = torch.randint(high=mutable.length, size=(1, ))
-                result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
+                gen_index = torch.randint(high=len(mutable), size=(1, ))
+                result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool()
             elif isinstance(mutable, InputChoice):
                 if mutable.n_chosen is None:
                     result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
......
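The random mutator draws a uniform index over the candidates and encodes it as a boolean one-hot mask, which is exactly the mask format the base mutator consumes. The same pattern in isolation:

    import torch
    import torch.nn.functional as F

    num_candidates = 3  # stands in for len(mutable)
    gen_index = torch.randint(high=num_candidates, size=(1,))
    mask = F.one_hot(gen_index, num_classes=num_candidates).view(-1).bool()
    # e.g. tensor([False, True, False]): exactly one candidate is active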
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from .layer_choice_only import LayerChoiceOnlySearchSpace
 from .mutable_scope import SpaceWithMutableScope
 from .naive import NaiveSearchSpace
 from .nested import NestedSpace
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from nni.nas.pytorch.mutables import LayerChoice
+
+
+class LayerChoiceOnlySearchSpace(nn.Module):
+    def __init__(self, test_case):
+        super().__init__()
+        self.test_case = test_case
+        self.conv1 = LayerChoice([nn.Conv2d(3, 6, 3, padding=1), nn.Conv2d(3, 6, 5, padding=2)])
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = LayerChoice([nn.Conv2d(6, 16, 3, padding=1), nn.Conv2d(6, 16, 5, padding=2)],
+                                 return_mask=True)
+        self.conv3 = nn.Conv2d(16, 16, 1)
+        self.bn = nn.BatchNorm2d(16)
+        self.gap = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(16, 10)
+
+    def forward(self, x):
+        bs = x.size(0)
+        x = self.pool(F.relu(self.conv1(x)))
+        x0, mask = self.conv2(x)
+        self.test_case.assertEqual(mask.size(), torch.Size([2]))
+        x1 = F.relu(self.conv3(x0))
+        x = self.pool(self.bn(x1))
+        self.test_case.assertEqual(mask.size(), torch.Size([2]))
+        x = self.gap(x).view(bs, -1)
+        x = self.fc(x)
+        return x
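A sketch of how this layer-choice-only space might be driven outside the test harness, assuming the usual `Mutator.reset()` sampling API; the bare `TestCase()` instance only supplies `assertEqual`:

    import torch
    from unittest import TestCase
    from nni.nas.pytorch.random import RandomMutator

    space = LayerChoiceOnlySearchSpace(TestCase())
    mutator = RandomMutator(space)
    mutator.reset()                      # sample a random architecture
    out = space(torch.randn(2, 3, 32, 32))
    assert out.size() == torch.Size([2, 10])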
@@ -3,6 +3,7 @@
 import importlib
 import os
 import sys
+from collections import OrderedDict
 from unittest import TestCase, main
 
 import torch
@@ -11,6 +12,7 @@
 from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture
 from nni.nas.pytorch.darts import DartsMutator
 from nni.nas.pytorch.enas import EnasMutator
 from nni.nas.pytorch.fixed import apply_fixed_architecture
+from nni.nas.pytorch.mutables import LayerChoice
 from nni.nas.pytorch.random import RandomMutator
 from nni.nas.pytorch.utils import _reset_global_mutable_counting
@@ -101,6 +103,43 @@ class NasTestCase(TestCase):
         get_and_apply_next_architecture(model)
         self.iterative_sample_and_forward(model)
 
+    def test_proxylessnas(self):
+        model = self.model_module.LayerChoiceOnlySearchSpace(self)
+        get_and_apply_next_architecture(model)
+        self.iterative_sample_and_forward(model)
+
+    def test_layer_choice(self):
+        for i in range(2):
+            for j in range(2):
+                if j == 0:
+                    # test number
+                    layer_choice = LayerChoice([nn.Conv2d(3, 3, 3), nn.Conv2d(3, 5, 3), nn.Conv2d(3, 6, 3)])
+                else:
+                    # test ordered dict
+                    layer_choice = LayerChoice(OrderedDict([
+                        ("conv1", nn.Conv2d(3, 3, 3)),
+                        ("conv2", nn.Conv2d(3, 5, 3)),
+                        ("conv3", nn.Conv2d(3, 6, 3))
+                    ]))
+                if i == 0:
+                    # test modify
+                    self.assertEqual(len(layer_choice.choices), 3)
+                    layer_choice[1] = nn.Conv2d(3, 4, 3)
+                    self.assertEqual(layer_choice[1].out_channels, 4)
+                    self.assertEqual(len(layer_choice[0:2]), 2)
+                    if j > 0:
+                        layer_choice["conv3"] = nn.Conv2d(3, 7, 3)
+                        self.assertEqual(layer_choice[-1].out_channels, 7)
+                if i == 1:
+                    # test delete
+                    del layer_choice[1]
+                    self.assertEqual(len(layer_choice), 2)
+                    self.assertEqual(len(list(layer_choice)), 2)
+                    self.assertEqual(layer_choice.names, ["conv1", "conv3"] if j > 0 else ["0", "2"])
+                    if j > 0:
+                        del layer_choice["conv1"]
+                        self.assertEqual(len(layer_choice), 1)
 
 if __name__ == '__main__':
     main()