Unverified Commit 6cb5916f authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Support OrderedDict for LayerChoice (#2336)

parent 319ff036
......@@ -8,6 +8,8 @@ https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import argparse
import logging
from collections import OrderedDict
import nni
import torch
import torch.nn as nn
......@@ -26,13 +28,15 @@ class Net(nn.Module):
def __init__(self, hidden_size):
super(Net, self).__init__()
# two options of conv1
self.conv1 = LayerChoice([nn.Conv2d(1, 20, 5, 1),
nn.Conv2d(1, 20, 3, 1)],
key='first_conv')
self.conv1 = LayerChoice(OrderedDict([
("conv5x5", nn.Conv2d(1, 20, 5, 1)),
("conv3x3", nn.Conv2d(1, 20, 3, 1))
]), key='first_conv')
# two options of mid_conv
self.mid_conv = LayerChoice([nn.Conv2d(20, 20, 3, 1, padding=1),
nn.Conv2d(20, 20, 5, 1, padding=2)],
key='mid_conv')
self.mid_conv = LayerChoice([
nn.Conv2d(20, 20, 3, 1, padding=1),
nn.Conv2d(20, 20, 5, 1, padding=2)
], key='mid_conv')
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, hidden_size)
self.fc2 = nn.Linear(hidden_size, 10)
......@@ -167,7 +171,6 @@ def get_params():
parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
help='how many batches to wait before logging training status')
args, _ = parser.parse_known_args()
return args
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict
import torch
import torch.nn as nn
......@@ -43,17 +45,15 @@ class Node(nn.Module):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
mutables.LayerChoice(
[
ops.PoolBN('max', channels, 3, stride, 1, affine=False),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
ops.SepConv(channels, channels, 3, stride, 1, affine=False),
ops.SepConv(channels, channels, 5, stride, 2, affine=False),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
],
key=choice_keys[-1]))
mutables.LayerChoice(OrderedDict([
("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)),
("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
]), key=choice_keys[-1]))
self.drop_path = ops.DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
......
......@@ -151,6 +151,5 @@ def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
for k, v in checkpoint["state_dict"].items():
if k.startswith("module."):
k = k[len("module."):]
k = re.sub(r"^(features.\d+).(\d+)", "\\1.choices.\\2", k)
result[k] = v
return result
......@@ -203,7 +203,7 @@ class ClassicMutator(Mutator):
# for now we only generate flattened search space
if isinstance(mutable, LayerChoice):
key = mutable.key
val = [repr(choice) for choice in mutable.choices]
val = mutable.names
search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
elif isinstance(mutable, InputChoice):
key = mutable.key
......
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import logging
from collections import OrderedDict
import torch.nn as nn
......@@ -83,9 +84,6 @@ class Mutable(nn.Module):
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
def __repr__(self):
return "{} ({})".format(self.name, self.key)
class MutableScope(Mutable):
"""
......@@ -128,7 +126,7 @@ class LayerChoice(Mutable):
Parameters
----------
op_candidates : list of nn.Module
op_candidates : list of nn.Module or OrderedDict
A module list to be selected from.
reduction : str
``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected.
......@@ -143,12 +141,42 @@ class LayerChoice(Mutable):
----------
length : int
Number of ops to choose from.
names: list of str
Names of candidates.
Notes
-----
``op_candidates`` can be a list of modules or a ordered dict of named modules, for example,
.. code-block:: python
self.op_choice = LayerChoice(OrderedDict([
("conv3x3", nn.Conv2d(3, 16, 128)),
("conv5x5", nn.Conv2d(5, 16, 128)),
("conv7x7", nn.Conv2d(7, 16, 128))
]))
"""
def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
super().__init__(key=key)
self.length = len(op_candidates)
self.choices = nn.ModuleList(op_candidates)
self.choices = []
self.names = []
if isinstance(op_candidates, OrderedDict):
for name, module in op_candidates.items():
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
"Please don't use a reserved name '{}' for your module.".format(name)
self.add_module(name, module)
self.choices.append(module)
self.names.append(name)
elif isinstance(op_candidates, list):
for i, module in enumerate(op_candidates):
self.add_module(str(i), module)
self.choices.append(module)
self.names.append(str(i))
else:
raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
self.reduction = reduction
self.return_mask = return_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment