Unverified Commit 6cb5916f authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Support OrderedDict for LayerChoice (#2336)

parent 319ff036
...@@ -8,6 +8,8 @@ https://github.com/pytorch/examples/blob/master/mnist/main.py ...@@ -8,6 +8,8 @@ https://github.com/pytorch/examples/blob/master/mnist/main.py
import os import os
import argparse import argparse
import logging import logging
from collections import OrderedDict
import nni import nni
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -26,13 +28,15 @@ class Net(nn.Module): ...@@ -26,13 +28,15 @@ class Net(nn.Module):
def __init__(self, hidden_size): def __init__(self, hidden_size):
super(Net, self).__init__() super(Net, self).__init__()
# two options of conv1 # two options of conv1
self.conv1 = LayerChoice([nn.Conv2d(1, 20, 5, 1), self.conv1 = LayerChoice(OrderedDict([
nn.Conv2d(1, 20, 3, 1)], ("conv5x5", nn.Conv2d(1, 20, 5, 1)),
key='first_conv') ("conv3x3", nn.Conv2d(1, 20, 3, 1))
]), key='first_conv')
# two options of mid_conv # two options of mid_conv
self.mid_conv = LayerChoice([nn.Conv2d(20, 20, 3, 1, padding=1), self.mid_conv = LayerChoice([
nn.Conv2d(20, 20, 5, 1, padding=2)], nn.Conv2d(20, 20, 3, 1, padding=1),
key='mid_conv') nn.Conv2d(20, 20, 5, 1, padding=2)
], key='mid_conv')
self.conv2 = nn.Conv2d(20, 50, 5, 1) self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, hidden_size) self.fc1 = nn.Linear(4*4*50, hidden_size)
self.fc2 = nn.Linear(hidden_size, 10) self.fc2 = nn.Linear(hidden_size, 10)
...@@ -167,7 +171,6 @@ def get_params(): ...@@ -167,7 +171,6 @@ def get_params():
parser.add_argument('--log_interval', type=int, default=1000, metavar='N', parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
help='how many batches to wait before logging training status') help='how many batches to wait before logging training status')
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
return args return args
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -43,17 +45,15 @@ class Node(nn.Module): ...@@ -43,17 +45,15 @@ class Node(nn.Module):
stride = 2 if i < num_downsample_connect else 1 stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i)) choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append( self.ops.append(
mutables.LayerChoice( mutables.LayerChoice(OrderedDict([
[ ("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
ops.PoolBN('max', channels, 3, stride, 1, affine=False), ("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False), ("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)),
nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False), ("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
ops.SepConv(channels, channels, 3, stride, 1, affine=False), ("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
ops.SepConv(channels, channels, 5, stride, 2, affine=False), ("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False), ("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False) ]), key=choice_keys[-1]))
],
key=choice_keys[-1]))
self.drop_path = ops.DropPath() self.drop_path = ops.DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id)) self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
......
...@@ -151,6 +151,5 @@ def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"): ...@@ -151,6 +151,5 @@ def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
for k, v in checkpoint["state_dict"].items(): for k, v in checkpoint["state_dict"].items():
if k.startswith("module."): if k.startswith("module."):
k = k[len("module."):] k = k[len("module."):]
k = re.sub(r"^(features.\d+).(\d+)", "\\1.choices.\\2", k)
result[k] = v result[k] = v
return result return result
...@@ -203,7 +203,7 @@ class ClassicMutator(Mutator): ...@@ -203,7 +203,7 @@ class ClassicMutator(Mutator):
# for now we only generate flattened search space # for now we only generate flattened search space
if isinstance(mutable, LayerChoice): if isinstance(mutable, LayerChoice):
key = mutable.key key = mutable.key
val = [repr(choice) for choice in mutable.choices] val = mutable.names
search_space[key] = {"_type": LAYER_CHOICE, "_value": val} search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
elif isinstance(mutable, InputChoice): elif isinstance(mutable, InputChoice):
key = mutable.key key = mutable.key
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import logging import logging
from collections import OrderedDict
import torch.nn as nn import torch.nn as nn
...@@ -83,9 +84,6 @@ class Mutable(nn.Module): ...@@ -83,9 +84,6 @@ class Mutable(nn.Module):
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` " "Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self)) "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
def __repr__(self):
return "{} ({})".format(self.name, self.key)
class MutableScope(Mutable): class MutableScope(Mutable):
""" """
...@@ -128,7 +126,7 @@ class LayerChoice(Mutable): ...@@ -128,7 +126,7 @@ class LayerChoice(Mutable):
Parameters Parameters
---------- ----------
op_candidates : list of nn.Module op_candidates : list of nn.Module or OrderedDict
A module list to be selected from. A module list to be selected from.
reduction : str reduction : str
``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected. ``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected.
...@@ -143,12 +141,42 @@ class LayerChoice(Mutable): ...@@ -143,12 +141,42 @@ class LayerChoice(Mutable):
---------- ----------
length : int length : int
Number of ops to choose from. Number of ops to choose from.
names: list of str
Names of candidates.
Notes
-----
``op_candidates`` can be a list of modules or an ordered dict of named modules, for example,
.. code-block:: python
self.op_choice = LayerChoice(OrderedDict([
("conv3x3", nn.Conv2d(3, 16, 128)),
("conv5x5", nn.Conv2d(5, 16, 128)),
("conv7x7", nn.Conv2d(7, 16, 128))
]))
""" """
def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
    """
    Build a layer choice from candidate modules.

    Parameters
    ----------
    op_candidates : list of nn.Module or OrderedDict
        Candidates to choose from. An ``OrderedDict`` gives each candidate a
        human-readable name; a plain list names candidates by their index
        ("0", "1", ...).
    reduction : str
        ``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected.
    return_mask : bool
        If ``True``, the forward pass also returns the selection mask.
    key : str or None
        Mutable key; auto-generated when ``None``.

    Raises
    ------
    TypeError
        If ``op_candidates`` is neither a list nor an ``OrderedDict``.
    """
    super().__init__(key=key)
    self.length = len(op_candidates)
    # Keep a plain-list view of the candidates alongside their names.
    # Modules are registered individually via add_module (not nn.ModuleList)
    # so that each candidate appears in the state dict under its given name.
    self.choices = []
    self.names = []
    if isinstance(op_candidates, OrderedDict):
        for name, module in op_candidates.items():
            # Candidate names share the instance namespace with these
            # attributes, so reject names that would shadow them.
            assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
                "Please don't use a reserved name '{}' for your module.".format(name)
            self.add_module(name, module)
            self.choices.append(module)
            self.names.append(name)
    elif isinstance(op_candidates, list):
        for i, module in enumerate(op_candidates):
            # List candidates are named by their stringified index.
            self.add_module(str(i), module)
            self.choices.append(module)
            self.names.append(str(i))
    else:
        raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
    self.reduction = reduction
    self.return_mask = return_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment