Unverified Commit f5dbcd81 authored by Yuge Zhang, committed by GitHub

Fix layer choice on IT and deprecate "choices" and "length" (#2386)

parent 140eb682
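
This commit deprecates `LayerChoice.length` and `LayerChoice.choices` in favor of the standard container protocol (`len()`, iteration, indexing, item assignment and deletion), and updates the built-in mutators and tests accordingly. A minimal sketch of the recommended usage, based on the `nni.nas.pytorch.mutables.LayerChoice` changes in the diff below (candidate shapes are illustrative, not from this commit):

    import torch.nn as nn
    from nni.nas.pytorch.mutables import LayerChoice

    op_choice = LayerChoice([nn.Conv2d(3, 16, 3, padding=1), nn.Conv2d(3, 16, 5, padding=2)])

    n_candidates = len(op_choice)                  # replaces op_choice.length (now deprecated)
    candidates = list(op_choice)                   # replaces op_choice.choices (now deprecated)
    first_op = op_choice[0]                        # access by index (or by name for OrderedDict candidates)
    op_choice[1] = nn.Conv2d(3, 16, 7, padding=3)  # replace a candidate in place
    del op_choice[0]                               # delete a candidate; adding new candidates is not supported
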
@@ -92,8 +92,8 @@ class ClassicMutator(Mutator):
             The list for corresponding search space.
         """
         # doesn't support multihot for layer choice yet
-        onehot_list = [False] * mutable.length
-        assert 0 <= idx < mutable.length and search_space_item[idx] == value, \
+        onehot_list = [False] * len(mutable)
+        assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
             "Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
         onehot_list[idx] = True
         return torch.tensor(onehot_list, dtype=torch.bool)  # pylint: disable=not-callable
...
@@ -61,7 +61,7 @@ class DartsMutator(Mutator):
             if isinstance(mutable, LayerChoice):
                 max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
                 edges_max[mutable.key] = max_val
-                result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
+                result[mutable.key] = F.one_hot(index, num_classes=len(mutable)).view(-1).bool()
         for mutable in self.mutables:
             if isinstance(mutable, InputChoice):
                 if mutable.n_chosen is not None:
...
@@ -86,15 +86,15 @@ class EnasMutator(Mutator):
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
                 if self.max_layer_choice == 0:
-                    self.max_layer_choice = mutable.length
-                assert self.max_layer_choice == mutable.length, \
+                    self.max_layer_choice = len(mutable)
+                assert self.max_layer_choice == len(mutable), \
                     "ENAS mutator requires all layer choice have the same number of candidates."
                 # We are judging by keys and module types to add biases to layer choices. Needs refactor.
                 if "reduce" in mutable.key:
                     def is_conv(choice):
                         return "conv" in str(type(choice)).lower()
                     bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias  # pylint: disable=not-callable
-                                         for choice in mutable.choices])
+                                         for choice in mutable])
                     self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)

         self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
...
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

 import logging
+import warnings
 from collections import OrderedDict

 import torch.nn as nn
@@ -140,9 +141,12 @@ class LayerChoice(Mutable):
     Attributes
     ----------
     length : int
-        Number of ops to choose from.
-    names: list of str
+        Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended.
+    names : list of str
         Names of candidates.
+    choices : list of Module
+        Deprecated. A list of all candidate modules in the layer choice module.
+        ``list(layer_choice)`` is recommended, which will serve the same purpose.

     Notes
     -----
@@ -156,30 +160,65 @@ class LayerChoice(Mutable):
             ("conv7x7", nn.Conv2d(7, 16, 128))
         ]))

+    Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or
+    ``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
     """

     def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
         super().__init__(key=key)
-        self.length = len(op_candidates)
-        self.choices = []
         self.names = []
         if isinstance(op_candidates, OrderedDict):
             for name, module in op_candidates.items():
                 assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
                     "Please don't use a reserved name '{}' for your module.".format(name)
                 self.add_module(name, module)
-                self.choices.append(module)
                 self.names.append(name)
         elif isinstance(op_candidates, list):
             for i, module in enumerate(op_candidates):
                 self.add_module(str(i), module)
-                self.choices.append(module)
                 self.names.append(str(i))
         else:
             raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
         self.reduction = reduction
         self.return_mask = return_mask

+    def __getitem__(self, idx):
+        if isinstance(idx, str):
+            return self._modules[idx]
+        return list(self)[idx]
+
+    def __setitem__(self, idx, module):
+        key = idx if isinstance(idx, str) else self.names[idx]
+        return setattr(self, key, module)
+
+    def __delitem__(self, idx):
+        if isinstance(idx, slice):
+            for key in self.names[idx]:
+                delattr(self, key)
+        else:
+            if isinstance(idx, str):
+                key, idx = idx, self.names.index(idx)
+            else:
+                key = self.names[idx]
+            delattr(self, key)
+        del self.names[idx]
+
+    @property
+    def length(self):
+        warnings.warn("layer_choice.length is deprecated. Use `len(layer_choice)` instead.", DeprecationWarning)
+        return len(self)
+
+    def __len__(self):
+        return len(self.names)
+
+    def __iter__(self):
+        return map(lambda name: self._modules[name], self.names)
+
+    @property
+    def choices(self):
+        warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.", DeprecationWarning)
+        return list(self)
+
     def forward(self, *args, **kwargs):
         """
         Returns
...
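
With the properties above, existing code that still reads `length` or `choices` keeps working but emits a `DeprecationWarning`. A small sketch of how that behaves, assuming the class as patched above:

    import warnings
    import torch.nn as nn
    from nni.nas.pytorch.mutables import LayerChoice

    lc = LayerChoice([nn.Conv2d(3, 8, 3), nn.Conv2d(3, 8, 5)])
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert lc.length == len(lc) == 2   # deprecated spelling still returns the same value
        assert lc.choices == list(lc)      # deprecated spelling returns a fresh list of the same modules
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
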
@@ -150,16 +150,16 @@ class Mutator(BaseMutator):
         """
         if self._connect_all:
             return self._all_connect_tensor_reduction(mutable.reduction,
-                                                      [op(*args, **kwargs) for op in mutable.choices]), \
-                   torch.ones(mutable.length)
+                                                      [op(*args, **kwargs) for op in mutable]), \
+                   torch.ones(len(mutable))

         def _map_fn(op, args, kwargs):
             return op(*args, **kwargs)

         mask = self._get_decision(mutable)
-        assert len(mask) == len(mutable.choices), \
-            "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable.choices))
-        out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable.choices], mask)
+        assert len(mask) == len(mutable), \
+            "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable))
+        out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
         return self._tensor_reduction(mutable.reduction, out), mask

     def on_forward_input_choice(self, mutable, tensor_list):
...
@@ -32,7 +32,7 @@ class PdartsMutator(DartsMutator):
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
-                switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
+                switches = self.switches.get(mutable.key, [True for j in range(len(mutable))])
                 choices = self.choices[mutable.key]
                 operations_count = np.sum(switches)
@@ -48,12 +48,12 @@ class PdartsMutator(DartsMutator):
             if isinstance(module, LayerChoice):
                 switches = self.switches.get(module.key)
                 choices = self.choices[module.key]
-                if len(module.choices) > len(choices):
+                if len(module) > len(choices):
                     # from last to first, so that it won't effect previous indexes after removed one.
                     for index in range(len(switches)-1, -1, -1):
                         if switches[index] == False:
-                            del(module.choices[index])
-                            module.length -= 1
+                            del module[index]
+                    assert len(module) <= len(choices), "Failed to remove dropped choices."

     def sample_final(self):
         results = super().sample_final()
...
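
Because `choices` is now a computed property that returns a fresh list, the old `del(module.choices[index])` would only mutate a temporary copy; dropped operations have to be removed through `del module[index]`, walking indices from last to first so the remaining indices do not shift. A minimal sketch of that pattern (the switch values are illustrative):

    import torch.nn as nn
    from nni.nas.pytorch.mutables import LayerChoice

    module = LayerChoice([nn.Conv2d(3, 8, 1), nn.Conv2d(3, 8, 3), nn.Conv2d(3, 8, 5)])
    switches = [True, False, True]              # keep candidates 0 and 2, drop candidate 1
    for index in range(len(switches) - 1, -1, -1):
        if not switches[index]:
            del module[index]                   # removes the submodule and its entry in module.names
    assert len(module) == sum(switches)
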
@@ -53,15 +53,15 @@ class MixedOp(nn.Module):
             A LayerChoice in user model
         """
         super(MixedOp, self).__init__()
-        self.ap_path_alpha = nn.Parameter(torch.Tensor(mutable.length))
-        self.ap_path_wb = nn.Parameter(torch.Tensor(mutable.length))
+        self.ap_path_alpha = nn.Parameter(torch.Tensor(len(mutable)))
+        self.ap_path_wb = nn.Parameter(torch.Tensor(len(mutable)))
         self.ap_path_alpha.requires_grad = False
         self.ap_path_wb.requires_grad = False
         self.active_index = [0]
         self.inactive_index = None
         self.log_prob = None
         self.current_prob_over_ops = None
-        self.n_choices = mutable.length
+        self.n_choices = len(mutable)

     def get_ap_path_alpha(self):
         return self.ap_path_alpha
@@ -120,8 +120,8 @@
                     return binary_grads
                 return backward
             output = ArchGradientFunction.apply(
-                x, self.ap_path_wb, run_function(mutable.key, mutable.choices, self.active_index[0]),
-                backward_function(mutable.key, mutable.choices, self.active_index[0], self.ap_path_wb))
+                x, self.ap_path_wb, run_function(mutable.key, list(mutable), self.active_index[0]),
+                backward_function(mutable.key, list(mutable), self.active_index[0], self.ap_path_wb))
         else:
             output = self.active_op(mutable)(x)
         return output
@@ -164,7 +164,7 @@
         PyTorch module
             the chosen operation
         """
-        return mutable.choices[self.active_index[0]]
+        return mutable[self.active_index[0]]

     @property
     def active_op_index(self):
@@ -222,12 +222,12 @@
             sample = torch.multinomial(probs, 1)[0].item()
             self.active_index = [sample]
             self.inactive_index = [_i for _i in range(0, sample)] + \
-                                  [_i for _i in range(sample + 1, len(mutable.choices))]
+                                  [_i for _i in range(sample + 1, len(mutable))]
             self.log_prob = torch.log(probs[sample])
             self.current_prob_over_ops = probs
             self.ap_path_wb.data[sample] = 1.0
             # avoid over-regularization
-            for choice in mutable.choices:
+            for choice in mutable:
                 for _, param in choice.named_parameters():
                     param.grad = None
@@ -430,8 +430,8 @@ class ProxylessNasMutator(BaseMutator):
             involved_index = mixed_op.active_index
             for i in range(mixed_op.n_choices):
                 if i not in involved_index:
-                    unused[i] = mutable.choices[i]
-                    mutable.choices[i] = None
+                    unused[i] = mutable[i]
+                    mutable[i] = None
             self._unused_modules.append(unused)

     def unused_modules_back(self):
@@ -442,7 +442,7 @@
             return
         for m, unused in zip(self.mutable_list, self._unused_modules):
             for i in unused:
-                m.choices[i] = unused[i]
+                m[i] = unused[i]
         self._unused_modules = None

     def arch_requires_grad(self):
@@ -474,5 +474,5 @@
             assert isinstance(mutable, LayerChoice)
             index, _ = mutable.registered_module.chosen_index
             # pylint: disable=not-callable
-            result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=mutable.length).view(-1).bool()
+            result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=len(mutable)).view(-1).bool()
         return result
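
The ProxylessNAS mutator uses item assignment on the layer choice to detach the operations that are not active and to put them back afterwards. A minimal sketch of that swap, assuming a single active index (the value 0 is illustrative):

    import torch.nn as nn
    from nni.nas.pytorch.mutables import LayerChoice

    mutable = LayerChoice([nn.Conv2d(3, 8, 3), nn.Conv2d(3, 8, 5)])
    active, unused = 0, {}
    for i in range(len(mutable)):
        if i != active:
            unused[i] = mutable[i]
            mutable[i] = None        # detach the inactive op so it is no longer a registered submodule
    # ... run forward/backward with the active op only ...
    for i in unused:
        mutable[i] = unused[i]       # restore the detached ops
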
@@ -18,8 +18,8 @@ class RandomMutator(Mutator):
         result = dict()
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
-                gen_index = torch.randint(high=mutable.length, size=(1, ))
-                result[mutable.key] = F.one_hot(gen_index, num_classes=mutable.length).view(-1).bool()
+                gen_index = torch.randint(high=len(mutable), size=(1, ))
+                result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool()
             elif isinstance(mutable, InputChoice):
                 if mutable.n_chosen is None:
                     result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).view(-1).bool()
...
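
The random mutator samples a single index per layer choice and converts it into a boolean one-hot mask of length `len(mutable)`. A tiny sketch of that sampling step, with an illustrative four-candidate choice standing in for `len(mutable)`:

    import torch
    import torch.nn.functional as F

    num_candidates = 4                                    # stands in for len(mutable)
    gen_index = torch.randint(high=num_candidates, size=(1, ))
    mask = F.one_hot(gen_index, num_classes=num_candidates).view(-1).bool()
    assert mask.shape == (num_candidates,) and mask.sum() == 1
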
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from .layer_choice_only import LayerChoiceOnlySearchSpace
 from .mutable_scope import SpaceWithMutableScope
 from .naive import NaiveSearchSpace
 from .nested import NestedSpace

+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from nni.nas.pytorch.mutables import LayerChoice
+
+
+class LayerChoiceOnlySearchSpace(nn.Module):
+    def __init__(self, test_case):
+        super().__init__()
+        self.test_case = test_case
+        self.conv1 = LayerChoice([nn.Conv2d(3, 6, 3, padding=1), nn.Conv2d(3, 6, 5, padding=2)])
+        self.pool = nn.MaxPool2d(2, 2)
+        self.conv2 = LayerChoice([nn.Conv2d(6, 16, 3, padding=1), nn.Conv2d(6, 16, 5, padding=2)],
+                                 return_mask=True)
+        self.conv3 = nn.Conv2d(16, 16, 1)
+        self.bn = nn.BatchNorm2d(16)
+        self.gap = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(16, 10)
+
+    def forward(self, x):
+        bs = x.size(0)
+        x = self.pool(F.relu(self.conv1(x)))
+        x0, mask = self.conv2(x)
+        self.test_case.assertEqual(mask.size(), torch.Size([2]))
+        x1 = F.relu(self.conv3(x0))
+        x = self.pool(self.bn(x1))
+        self.test_case.assertEqual(mask.size(), torch.Size([2]))
+        x = self.gap(x).view(bs, -1)
+        x = self.fc(x)
+        return x

@@ -3,6 +3,7 @@
 import importlib
 import os
 import sys
+from collections import OrderedDict
 from unittest import TestCase, main

 import torch
@@ -11,6 +12,7 @@ from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture
 from nni.nas.pytorch.darts import DartsMutator
 from nni.nas.pytorch.enas import EnasMutator
 from nni.nas.pytorch.fixed import apply_fixed_architecture
+from nni.nas.pytorch.mutables import LayerChoice
 from nni.nas.pytorch.random import RandomMutator
 from nni.nas.pytorch.utils import _reset_global_mutable_counting
...
@@ -101,6 +103,43 @@ class NasTestCase(TestCase):
         get_and_apply_next_architecture(model)
         self.iterative_sample_and_forward(model)

+    def test_proxylessnas(self):
+        model = self.model_module.LayerChoiceOnlySearchSpace(self)
+        get_and_apply_next_architecture(model)
+        self.iterative_sample_and_forward(model)
+
+    def test_layer_choice(self):
+        for i in range(2):
+            for j in range(2):
+                if j == 0:
+                    # test number
+                    layer_choice = LayerChoice([nn.Conv2d(3, 3, 3), nn.Conv2d(3, 5, 3), nn.Conv2d(3, 6, 3)])
+                else:
+                    # test ordered dict
+                    layer_choice = LayerChoice(OrderedDict([
+                        ("conv1", nn.Conv2d(3, 3, 3)),
+                        ("conv2", nn.Conv2d(3, 5, 3)),
+                        ("conv3", nn.Conv2d(3, 6, 3))
+                    ]))
+                if i == 0:
+                    # test modify
+                    self.assertEqual(len(layer_choice.choices), 3)
+                    layer_choice[1] = nn.Conv2d(3, 4, 3)
+                    self.assertEqual(layer_choice[1].out_channels, 4)
+                    self.assertEqual(len(layer_choice[0:2]), 2)
+                    if j > 0:
+                        layer_choice["conv3"] = nn.Conv2d(3, 7, 3)
+                        self.assertEqual(layer_choice[-1].out_channels, 7)
+                if i == 1:
+                    # test delete
+                    del layer_choice[1]
+                    self.assertEqual(len(layer_choice), 2)
+                    self.assertEqual(len(list(layer_choice)), 2)
+                    self.assertEqual(layer_choice.names, ["conv1", "conv3"] if j > 0 else ["0", "2"])
+                    if j > 0:
+                        del layer_choice["conv1"]
+                        self.assertEqual(len(layer_choice), 1)
+
 if __name__ == '__main__':
     main()