Unverified commit aa316742, authored by SparkSnail, committed by GitHub

Merge pull request #233 from microsoft/master

merge master
parents 3fe117f0 24fa4619
...@@ -19,7 +19,19 @@ class Mutable(nn.Module):
decisions among different mutables. In mutator's implementation, mutators should use the key to
distinguish different mutables. Mutables that share the same key should be "similar" to each other.
Currently the default scope for keys is global. By default, keys use a global counter from 1 to
produce unique ids.
Parameters
----------
key : str
The key of mutable.
Notes
-----
The counter is program level, but mutables are model level. In case multiple models are defined, and
you want to have `counter` starting from 1 in the second model, it's recommended to assign keys manually
instead of using automatic keys.
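For example, a minimal sketch of assigning keys manually (``conv3x3``/``conv5x5`` are hypothetical helpers):
.. code-block:: python
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # explicit keys do not depend on the global counter,
            # so defining a second model will not shift them
            self.op1 = LayerChoice([conv3x3(), conv5x5()], key="net_op1")
            self.op2 = LayerChoice([conv3x3(), conv5x5()], key="net_op2")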
""" """
def __init__(self, key=None): def __init__(self, key=None):
...@@ -51,10 +63,16 @@ class Mutable(nn.Module): ...@@ -51,10 +63,16 @@ class Mutable(nn.Module):
@property @property
def key(self): def key(self):
"""
Read-only property of key.
"""
return self._key
@property
def name(self):
"""
After the search space is parsed, it will be the module name of the mutable.
"""
return self._name if hasattr(self, "_name") else "_key"
@name.setter
...@@ -75,11 +93,23 @@ class Mutable(nn.Module):
class MutableScope(Mutable):
"""
Mutable scope marks a subgraph/submodule to help mutators make better decisions.
If not annotated with mutable scope, the search space will be flattened as a list. However, some mutators might
need to leverage the concept of a "cell". So if a module is defined as a mutable scope, everything in it will
look like a "sub-search-space" in the scope. Scopes can be nested.
There are two ways mutators can use mutable scope. One is to traverse the search space as a tree during initialization
and reset. The other is to implement ``enter_mutable_scope`` and ``exit_mutable_scope``, which are called before and after
the forward method of the class that inherits mutable scope.
Mutable scopes are also mutables that are listed in ``mutator.mutables`` (search space), but they are not supposed
to appear in the dict of choices.
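For instance, a minimal sketch of grouping a searchable cell under a scope (``conv3x3``/``conv5x5`` are hypothetical helpers):
.. code-block:: python
    class Cell(MutableScope):
        def __init__(self, key):
            super().__init__(key=key)
            # everything defined here is grouped under this scope in the search space
            self.op = LayerChoice([conv3x3(), conv5x5()], key=key + "_op")
        def forward(self, x):
            return self.op(x)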
Parameters
----------
key : str
Key of mutable scope.
"""
def __init__(self, key):
super().__init__(key=key)
...@@ -93,6 +123,31 @@ class MutableScope(Mutable):
class LayerChoice(Mutable):
"""
Layer choice selects one of the ``op_candidates``, then applies it on inputs and returns results.
In rare cases, it can also select zero or many.
Layer choice does not allow itself to be nested.
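A minimal usage sketch (``conv3x3``/``conv5x5`` are hypothetical helpers):
.. code-block:: python
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.op = LayerChoice([conv3x3(), conv5x5()], key="op")
        def forward(self, x):
            # the mutator decides which candidate(s) actually run
            return self.op(x)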
Parameters
----------
op_candidates : list of nn.Module
A module list to be selected from.
reduction : str
``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected.
If ``none``, a list is returned. ``mean`` returns the average. ``sum`` returns the sum.
``concat`` concatenates the list at dimension 1.
return_mask : bool
If ``return_mask``, return output tensor and a mask. Otherwise return tensor only.
key : str
Key of the layer choice.
Attributes
----------
length : int
Number of ops to choose from.
"""
def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
super().__init__(key=key)
self.length = len(op_candidates)
...@@ -101,6 +156,12 @@ class LayerChoice(Mutable):
self.return_mask = return_mask
def forward(self, *inputs):
"""
Returns
-------
tuple of tensors
Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
"""
out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask:
return out, mask
...@@ -109,42 +170,62 @@ class LayerChoice(Mutable):
class InputChoice(Mutable):
"""
Input choice selects ``n_chosen`` inputs from ``choose_from`` (contains ``n_candidates`` keys). For beginners,
using ``n_candidates`` instead of ``choose_from`` is a safe option. To get the most power out of it, you might want to
know about ``choose_from``.
The keys in ``choose_from`` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones.
The keys are designed to be the keys of the sources. To help mutators make better decisions,
mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the
output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g.,
``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a
module/submodule, it needs to be annotated with a key: that's where a :class:`MutableScope` is needed.
In the example below, ``input_choice`` is a 4-choose-any. The first three are semantically the output of cell1, the
output of cell2, and the output of op, respectively. Notice that an extra max pooling follows cell1, indicating x1 is
not "actually" the direct output of cell1.
.. code-block:: python
class Cell(MutableScope):
pass
class Net(nn.Module):
def __init__(self):
self.cell1 = Cell("cell1")
self.cell2 = Cell("cell2")
self.op = LayerChoice([conv3x3(), conv5x5()], key="op")
self.input_choice = InputChoice(choose_from=["cell1", "cell2", "op", InputChoice.NO_KEY])
def forward(self, x):
x1 = max_pooling(self.cell1(x))
x2 = self.cell2(x)
x3 = self.op(x)
x4 = torch.zeros_like(x)
return self.input_choice([x1, x2, x3, x4])
Parameters
----------
n_candidates : int
Number of inputs to choose from.
choose_from : list of str
List of source keys to choose from. At least one of ``choose_from`` and ``n_candidates`` must be fulfilled.
If ``n_candidates`` has a value but ``choose_from`` is None, it will be automatically treated as ``n_candidates``
number of empty strings.
n_chosen : int
Recommended inputs to choose. If None, mutator is instructed to select any.
reduction : str
``mean``, ``concat``, ``sum`` or ``none``. See :class:`LayerChoice`.
return_mask : bool
If ``return_mask``, return output tensor and a mask. Otherwise return tensor only.
key : str
Key of the input choice.
""" """
NO_KEY = "" NO_KEY = ""
def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
reduction="sum", return_mask=False, key=None): reduction="sum", return_mask=False, key=None):
"""
Initialization.
Parameters
----------
n_candidates : int
Number of inputs to choose from.
choose_from : list of str
List of source keys to choose from. At least of one of `choose_from` and `n_candidates` must be fulfilled.
If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates`
number of empty string.
n_chosen : int
Recommended inputs to choose. If None, mutator is instructed to select any.
reduction : str
`mean`, `concat`, `sum` or `none`.
return_mask : bool
If `return_mask`, return output tensor and a mask. Otherwise return tensor only.
key : str
Key of the input choice.
"""
super().__init__(key=key)
# precondition check
assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \
...@@ -172,12 +253,13 @@ class InputChoice(Mutable):
----------
optional_inputs : list or dict
Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
``choose_from`` in initialization. As a list, inputs must follow the semantic order that is the same as
``choose_from``.
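For example, a sketch that reuses the keys from the class-level example (assuming the dict covers every key in ``choose_from``):
.. code-block:: python
    out = self.input_choice({"cell1": x1, "cell2": x2, "op": x3, InputChoice.NO_KEY: x4})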
Returns
-------
tuple of tensors
Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
"""
optional_input_list = optional_inputs
if isinstance(optional_inputs, dict):
......
...@@ -43,10 +43,6 @@ class Mutator(BaseMutator):
"""
Reset the mutator by calling `sample_search` to resample (for search). Stores the result in a local
variable so that `on_forward_layer_choice` and `on_forward_input_choice` can use the decision directly.
""" """
self._cache = self.sample_search() self._cache = self.sample_search()
...@@ -57,25 +53,28 @@ class Mutator(BaseMutator): ...@@ -57,25 +53,28 @@ class Mutator(BaseMutator):
Returns Returns
------- -------
dict dict
A mapping from key of mutables to decisions.
""" """
return self.sample_final() return self.sample_final()
def on_forward_layer_choice(self, mutable, *inputs): def on_forward_layer_choice(self, mutable, *inputs):
""" """
By default, this method retrieves the decision obtained previously, and selects certain operations.
Only operations with non-zero weight will be executed. The results will be added to a list.
Then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`.
Parameters
----------
mutable : LayerChoice
Layer choice module.
inputs : list of torch.Tensor
Input tensors.
Returns
-------
tuple of torch.Tensor and torch.Tensor
Output and mask.
"""
def _map_fn(op, *inputs):
return op(*inputs)
...@@ -87,20 +86,20 @@ class Mutator(BaseMutator):
def on_forward_input_choice(self, mutable, tensor_list):
"""
By default, this method retrieves the decision obtained previously, and selects certain tensors.
Then it will reduce the list of all tensor outputs with the policy specified in `mutable.reduction`.
Parameters
----------
mutable : InputChoice
Input choice module.
tensor_list : list of torch.Tensor
Tensor list to apply the decision on.
Returns
-------
tuple of torch.Tensor and torch.Tensor
Output and mask.
"""
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates, \
......
from .mutator import ProxylessNasMutator
from .trainer import ProxylessNasTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
import numpy as np
from nni.nas.pytorch.base_mutator import BaseMutator
from nni.nas.pytorch.mutables import LayerChoice
from .utils import detach_variable
class ArchGradientFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, binary_gates, run_func, backward_func):
ctx.run_func = run_func
ctx.backward_func = backward_func
detached_x = detach_variable(x)
with torch.enable_grad():
output = run_func(detached_x)
ctx.save_for_backward(detached_x, output)
return output.data
@staticmethod
def backward(ctx, grad_output):
detached_x, output = ctx.saved_tensors
grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
# compute gradients w.r.t. binary_gates
binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
return grad_x[0], binary_grads, None, None
class MixedOp(nn.Module):
"""
This class is to instantiate and manage info of one LayerChoice.
It includes architecture weights, binary weights, and member functions
operating the weights.
forward_mode:
forward/backward mode for LayerChoice: None, two, full, and full_v2.
For training architecture weights, we use full_v2 by default, and for training
model weights, we use None.
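A sketch of how the trainer toggles this mode (see ``ProxylessNasMutator.change_forward_mode`` below; ``mutator`` is a hypothetical instance):
.. code-block:: python
    mutator.change_forward_mode('full_v2')  # before an architecture-weight update step
    # ... run forward/backward and the architecture optimizer here ...
    mutator.change_forward_mode(None)       # back to plain model-weight training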
"""
forward_mode = None
def __init__(self, mutable):
"""
Parameters
----------
mutable : LayerChoice
A LayerChoice in user model
"""
super(MixedOp, self).__init__()
self.ap_path_alpha = nn.Parameter(torch.Tensor(mutable.length))
self.ap_path_wb = nn.Parameter(torch.Tensor(mutable.length))
self.ap_path_alpha.requires_grad = False
self.ap_path_wb.requires_grad = False
self.active_index = [0]
self.inactive_index = None
self.log_prob = None
self.current_prob_over_ops = None
self.n_choices = mutable.length
def get_ap_path_alpha(self):
return self.ap_path_alpha
def to_requires_grad(self):
self.ap_path_alpha.requires_grad = True
self.ap_path_wb.requires_grad = True
def to_disable_grad(self):
self.ap_path_alpha.requires_grad = False
self.ap_path_wb.requires_grad = False
def forward(self, mutable, x):
"""
Define forward of LayerChoice. For 'full_v2', backward is also defined.
The 'two' mode is explained in section 3.2.1 in the paper.
The 'full_v2' mode is explained in Appendix D in the paper.
Parameters
----------
mutable : LayerChoice
this layer's mutable
x : tensor
inputs of this layer, only support one input
Returns
-------
output: tensor
output of this layer
"""
if MixedOp.forward_mode == 'full' or MixedOp.forward_mode == 'two':
output = 0
for _i in self.active_index:
oi = self.candidate_ops[_i](x)
output = output + self.ap_path_wb[_i] * oi
for _i in self.inactive_index:
oi = self.candidate_ops[_i](x)
output = output + self.ap_path_wb[_i] * oi.detach()
elif MixedOp.forward_mode == 'full_v2':
def run_function(key, candidate_ops, active_id):
def forward(_x):
return candidate_ops[active_id](_x)
return forward
def backward_function(key, candidate_ops, active_id, binary_gates):
def backward(_x, _output, grad_output):
binary_grads = torch.zeros_like(binary_gates.data)
with torch.no_grad():
for k in range(len(candidate_ops)):
if k != active_id:
out_k = candidate_ops[k](_x.data)
else:
out_k = _output.data
grad_k = torch.sum(out_k * grad_output)
binary_grads[k] = grad_k
return binary_grads
return backward
output = ArchGradientFunction.apply(
x, self.ap_path_wb, run_function(mutable.key, mutable.choices, self.active_index[0]),
backward_function(mutable.key, mutable.choices, self.active_index[0], self.ap_path_wb))
else:
output = self.active_op(mutable)(x)
return output
@property
def probs_over_ops(self):
"""
Apply softmax on alpha to generate probability distribution
Returns
-------
pytorch tensor
probability distribution
"""
probs = F.softmax(self.ap_path_alpha, dim=0) # softmax to probability
return probs
@property
def chosen_index(self):
"""
choose the op with max prob
Returns
-------
int
index of the chosen one
numpy.float32
prob of the chosen one
"""
probs = self.probs_over_ops.data.cpu().numpy()
index = int(np.argmax(probs))
return index, probs[index]
def active_op(self, mutable):
"""
assume only one path is active
Returns
-------
PyTorch module
the chosen operation
"""
return mutable.choices[self.active_index[0]]
@property
def active_op_index(self):
"""
return active op's index, the active op is sampled
Returns
-------
int
index of the active op
"""
return self.active_index[0]
def set_chosen_op_active(self):
"""
set chosen index, active and inactive indexes
"""
chosen_idx, _ = self.chosen_index
self.active_index = [chosen_idx]
self.inactive_index = [_i for _i in range(0, chosen_idx)] + \
[_i for _i in range(chosen_idx + 1, self.n_choices)]
def binarize(self, mutable):
"""
Sample based on alpha, and set binary weights accordingly.
ap_path_wb is set in this function, hence the name binarize.
Parameters
----------
mutable : LayerChoice
this layer's mutable
"""
self.log_prob = None
# reset binary gates
self.ap_path_wb.data.zero_()
probs = self.probs_over_ops
if MixedOp.forward_mode == 'two':
# sample two ops according to probs
sample_op = torch.multinomial(probs.data, 2, replacement=False)
probs_slice = F.softmax(torch.stack([
self.ap_path_alpha[idx] for idx in sample_op
]), dim=0)
self.current_prob_over_ops = torch.zeros_like(probs)
for i, idx in enumerate(sample_op):
self.current_prob_over_ops[idx] = probs_slice[i]
# choose one to be active and the other to be inactive according to probs_slice
c = torch.multinomial(probs_slice.data, 1)[0] # 0 or 1
active_op = sample_op[c].item()
inactive_op = sample_op[1-c].item()
self.active_index = [active_op]
self.inactive_index = [inactive_op]
# set binary gate
self.ap_path_wb.data[active_op] = 1.0
else:
sample = torch.multinomial(probs, 1)[0].item()
self.active_index = [sample]
self.inactive_index = [_i for _i in range(0, sample)] + \
[_i for _i in range(sample + 1, len(mutable.choices))]
self.log_prob = torch.log(probs[sample])
self.current_prob_over_ops = probs
self.ap_path_wb.data[sample] = 1.0
# avoid over-regularization
for choice in mutable.choices:
for _, param in choice.named_parameters():
param.grad = None
@staticmethod
def delta_ij(i, j):
if i == j:
return 1
else:
return 0
def set_arch_param_grad(self, mutable):
"""
Calculate alpha gradient for this LayerChoice.
It is calculated from the gradients of the binary gates and the probabilities of the ops.
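A restatement of the update rule implemented below, where ``g_j`` are the binary gates and ``p_j`` the softmax probabilities:
.. math::
    \frac{\partial L}{\partial \alpha_i} = \sum_j \frac{\partial L}{\partial g_j}\, p_j\, (\delta_{ij} - p_i)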
"""
binary_grads = self.ap_path_wb.grad.data
if self.active_op(mutable).is_zero_layer():
self.ap_path_alpha.grad = None
return
if self.ap_path_alpha.grad is None:
self.ap_path_alpha.grad = torch.zeros_like(self.ap_path_alpha.data)
if MixedOp.forward_mode == 'two':
involved_idx = self.active_index + self.inactive_index
probs_slice = F.softmax(torch.stack([
self.ap_path_alpha[idx] for idx in involved_idx
]), dim=0).data
for i in range(2):
for j in range(2):
origin_i = involved_idx[i]
origin_j = involved_idx[j]
self.ap_path_alpha.grad.data[origin_i] += \
binary_grads[origin_j] * probs_slice[j] * (MixedOp.delta_ij(i, j) - probs_slice[i])
for _i, idx in enumerate(self.active_index):
self.active_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
for _i, idx in enumerate(self.inactive_index):
self.inactive_index[_i] = (idx, self.ap_path_alpha.data[idx].item())
else:
probs = self.probs_over_ops.data
for i in range(self.n_choices):
for j in range(self.n_choices):
self.ap_path_alpha.grad.data[i] += binary_grads[j] * probs[j] * (MixedOp.delta_ij(i, j) - probs[i])
return
def rescale_updated_arch_param(self):
"""
rescale architecture weights for the 'two' mode.
"""
if not isinstance(self.active_index[0], tuple):
assert self.active_op.is_zero_layer()
return
involved_idx = [idx for idx, _ in (self.active_index + self.inactive_index)]
old_alphas = [alpha for _, alpha in (self.active_index + self.inactive_index)]
new_alphas = [self.ap_path_alpha.data[idx] for idx in involved_idx]
offset = math.log(
sum([math.exp(alpha) for alpha in new_alphas]) / sum([math.exp(alpha) for alpha in old_alphas])
)
for idx in involved_idx:
self.ap_path_alpha.data[idx] -= offset
class ProxylessNasMutator(BaseMutator):
"""
This mutator initializes and operates all the LayerChoices of the input model.
It is for the corresponding trainer to control the training process of LayerChoices,
coordinating with the whole training process.
"""
def __init__(self, model):
"""
Init a MixedOp instance for each mutable i.e., LayerChoice.
And register the instantiated MixedOp in corresponding LayerChoice.
If it is not registered in LayerChoice, DataParallel will not work,
because architecture weights are not included in the DataParallel model.
When MixedOps are registered, we use ```requires_grad``` to control
whether to calculate gradients of architecture weights.
Parameters
----------
model : pytorch model
The model that users want to tune, it includes search space defined with nni nas apis
"""
super(ProxylessNasMutator, self).__init__(model)
self._unused_modules = None
self.mutable_list = []
for mutable in self.undedup_mutables:
self.mutable_list.append(mutable)
mutable.registered_module = MixedOp(mutable)
def on_forward_layer_choice(self, mutable, *inputs):
"""
Callback of layer choice forward. This function defines the forward
logic of the input mutable. The mutable itself is only an interface; its real
implementation is defined in the mutator.
Parameters
----------
mutable: LayerChoice
the layer choice whose forward logic is being defined
inputs: list of torch.Tensor
inputs of this mutable
Returns
-------
torch.Tensor
output of this mutable, i.e., LayerChoice
int
index of the chosen op
"""
# FIXME: return mask, to be consistent with other algorithms
idx = mutable.registered_module.active_op_index
return mutable.registered_module(mutable, *inputs), idx
def reset_binary_gates(self):
"""
For each LayerChoice, binarize binary weights
based on alpha to only activate one op.
It traverses all the mutables in the model to do this.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.binarize(mutable)
def set_chosen_op_active(self):
"""
For each LayerChoice, set the op with highest alpha as the chosen op.
Usually used for validation.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.set_chosen_op_active()
def num_arch_params(self):
"""
The number of mutables, i.e., LayerChoice
Returns
-------
int
the number of LayerChoice in user model
"""
return len(self.mutable_list)
def set_arch_param_grad(self):
"""
For each LayerChoice, calculate gradients for architecture weights, i.e., alpha
"""
for mutable in self.undedup_mutables:
mutable.registered_module.set_arch_param_grad(mutable)
def get_architecture_parameters(self):
"""
Get all the architecture parameters.
Yields
------
PyTorch Parameter
Return ap_path_alpha of the traversed mutable
"""
for mutable in self.undedup_mutables:
yield mutable.registered_module.get_ap_path_alpha()
def change_forward_mode(self, mode):
"""
Update forward mode of MixedOps, as training architecture weights and
model weights use different forward modes.
"""
MixedOp.forward_mode = mode
def get_forward_mode(self):
"""
Get forward mode of MixedOp
Returns
-------
string
the current forward mode of MixedOp
"""
return MixedOp.forward_mode
def rescale_updated_arch_param(self):
"""
Rescale architecture weights in 'two' mode.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.rescale_updated_arch_param()
def unused_modules_off(self):
"""
Remove unused modules for each mutable.
The removed modules are kept in ```self._unused_modules``` for resume later.
"""
self._unused_modules = []
for mutable in self.undedup_mutables:
mixed_op = mutable.registered_module
unused = {}
if self.get_forward_mode() in ['full', 'two', 'full_v2']:
involved_index = mixed_op.active_index + mixed_op.inactive_index
else:
involved_index = mixed_op.active_index
for i in range(mixed_op.n_choices):
if i not in involved_index:
unused[i] = mutable.choices[i]
mutable.choices[i] = None
self._unused_modules.append(unused)
def unused_modules_back(self):
"""
Put the removed modules back.
"""
if self._unused_modules is None:
return
for m, unused in zip(self.mutable_list, self._unused_modules):
for i in unused:
m.choices[i] = unused[i]
self._unused_modules = None
def arch_requires_grad(self):
"""
Make architecture weights require gradient
"""
for mutable in self.undedup_mutables:
mutable.registered_module.to_requires_grad()
def arch_disable_grad(self):
"""
Disable gradient of architecture weights, i.e., do not
calculate gradients for them.
"""
for mutable in self.undedup_mutables:
mutable.registered_module.to_disable_grad()
def sample_final(self):
"""
Generate the final chosen architecture.
Returns
-------
dict
the choice of each mutable, i.e., LayerChoice
"""
result = dict()
for mutable in self.undedup_mutables:
assert isinstance(mutable, LayerChoice)
index, _ = mutable.registered_module.chosen_index
# pylint: disable=not-callable
result[mutable.key] = F.one_hot(torch.tensor(index), num_classes=mutable.length).view(-1).bool()
return result
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import time
import json
import logging
import torch
from torch import nn as nn
from nni.nas.pytorch.base_trainer import BaseTrainer
from nni.nas.pytorch.trainer import TorchTensorEncoder
from nni.nas.pytorch.utils import AverageMeter
from .mutator import ProxylessNasMutator
from .utils import cross_entropy_with_label_smoothing, accuracy
logger = logging.getLogger(__name__)
class ProxylessNasTrainer(BaseTrainer):
def __init__(self, model, model_optim, device,
train_loader, valid_loader, label_smoothing=0.1,
n_epochs=120, init_lr=0.025, binary_mode='full_v2',
arch_init_type='normal', arch_init_ratio=1e-3,
arch_optim_lr=1e-3, arch_weight_decay=0,
grad_update_arch_param_every=5, grad_update_steps=1,
warmup=True, warmup_epochs=25,
arch_valid_frequency=1,
load_ckpt=False, ckpt_path=None, arch_path=None):
"""
Parameters
----------
model : pytorch model
the user model, which has mutables
model_optim : pytorch optimizer
the user defined optimizer
device : pytorch device
the devices to train/search the model
train_loader : pytorch data loader
data loader for the training set
valid_loader : pytorch data loader
data loader for the validation set
label_smoothing : float
for label smoothing
n_epochs : int
number of epochs to train/search
init_lr : float
init learning rate for training the model
binary_mode : str
the forward/backward mode for the binary weights in mutator
arch_init_type : str
the way to init architecture parameters
arch_init_ratio : float
the ratio to init architecture parameters
arch_optim_lr : float
learning rate of the architecture parameters optimizer
arch_weight_decay : float
weight decay of the architecture parameters optimizer
grad_update_arch_param_every : int
update architecture weights every this number of minibatches
grad_update_steps : int
during each update of architecture weights, the number of steps to train
warmup : bool
whether to do warmup
warmup_epochs : int
the number of epochs to do during warmup
arch_valid_frequency : int
frequency of printing validation result
load_ckpt : bool
whether to load checkpoint
ckpt_path : str
checkpoint path, if load_ckpt is True, ckpt_path cannot be None
arch_path : str
the path to store chosen architecture
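A minimal usage sketch (assuming the model, optimizer and data loaders are defined elsewhere):
.. code-block:: python
    trainer = ProxylessNasTrainer(model, model_optim=optimizer, device=torch.device("cuda"),
                                  train_loader=train_loader, valid_loader=valid_loader)
    trainer.train()
    trainer.export("final_arch.json")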
"""
self.model = model
self.model_optim = model_optim
self.train_loader = train_loader
self.valid_loader = valid_loader
self.device = device
self.n_epochs = n_epochs
self.init_lr = init_lr
self.warmup = warmup
self.warmup_epochs = warmup_epochs
self.arch_valid_frequency = arch_valid_frequency
self.label_smoothing = label_smoothing
self.train_batch_size = train_loader.batch_sampler.batch_size
self.valid_batch_size = valid_loader.batch_sampler.batch_size
# update architecture parameters every this number of minibatches
self.grad_update_arch_param_every = grad_update_arch_param_every
# the number of steps per architecture parameter update
self.grad_update_steps = grad_update_steps
self.binary_mode = binary_mode
self.load_ckpt = load_ckpt
self.ckpt_path = ckpt_path
self.arch_path = arch_path
# init mutator
self.mutator = ProxylessNasMutator(model)
# DataParallel should be put behind the init of mutator
self.model = torch.nn.DataParallel(self.model)
self.model.to(self.device)
# iter of valid dataset for training architecture weights
self._valid_iter = None
# init architecture weights
self._init_arch_params(arch_init_type, arch_init_ratio)
# build architecture optimizer
self.arch_optimizer = torch.optim.Adam(self.mutator.get_architecture_parameters(),
arch_optim_lr,
weight_decay=arch_weight_decay,
betas=(0, 0.999),
eps=1e-8)
self.criterion = nn.CrossEntropyLoss()
self.warmup_curr_epoch = 0
self.train_curr_epoch = 0
def _init_arch_params(self, init_type='normal', init_ratio=1e-3):
"""
Initialize architecture weights
"""
for param in self.mutator.get_architecture_parameters():
if init_type == 'normal':
param.data.normal_(0, init_ratio)
elif init_type == 'uniform':
param.data.uniform_(-init_ratio, init_ratio)
else:
raise NotImplementedError
def _validate(self):
"""
Do validation. During validation, LayerChoices use the chosen active op.
Returns
-------
float, float, float
average loss, average top1 accuracy, average top5 accuracy
"""
self.valid_loader.batch_sampler.batch_size = self.valid_batch_size
self.valid_loader.batch_sampler.drop_last = False
self.mutator.set_chosen_op_active()
# remove unused modules to save memory
self.mutator.unused_modules_off()
# test on validation set under train mode
self.model.train()
batch_time = AverageMeter('batch_time')
losses = AverageMeter('losses')
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
end = time.time()
with torch.no_grad():
for i, (images, labels) in enumerate(self.valid_loader):
images, labels = images.to(self.device), labels.to(self.device)
output = self.model(images)
loss = self.criterion(output, labels)
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % 10 == 0 or i + 1 == len(self.valid_loader):
test_log = 'Valid' + ': [{0}/{1}]\t'\
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'\
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'\
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'.\
format(i, len(self.valid_loader) - 1, batch_time=batch_time, loss=losses, top1=top1)
# return top5:
test_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(top5=top5)
logger.info(test_log)
self.mutator.unused_modules_back()
return losses.avg, top1.avg, top5.avg
def _warm_up(self):
"""
Warm up the model. During warm-up, architecture weights are not trained.
"""
lr_max = 0.05
data_loader = self.train_loader
nBatch = len(data_loader)
T_total = self.warmup_epochs * nBatch # total num of batches
for epoch in range(self.warmup_curr_epoch, self.warmup_epochs):
logger.info('\n--------Warmup epoch: %d--------\n', epoch + 1)
batch_time = AverageMeter('batch_time')
data_time = AverageMeter('data_time')
losses = AverageMeter('losses')
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
# switch to train mode
self.model.train()
end = time.time()
logger.info('warm_up epoch: %d', epoch)
for i, (images, labels) in enumerate(data_loader):
data_time.update(time.time() - end)
# lr
T_cur = epoch * nBatch + i
warmup_lr = 0.5 * lr_max * (1 + math.cos(math.pi * T_cur / T_total))
for param_group in self.model_optim.param_groups:
param_group['lr'] = warmup_lr
images, labels = images.to(self.device), labels.to(self.device)
# compute output
self.mutator.reset_binary_gates() # random sample binary gates
self.mutator.unused_modules_off() # remove unused module for speedup
output = self.model(images)
if self.label_smoothing > 0:
loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
else:
loss = self.criterion(output, labels)
# measure accuracy and record loss
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
# compute gradient and do SGD step
self.model.zero_grad()
loss.backward()
self.model_optim.step()
# unused modules back
self.mutator.unused_modules_back()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % 10 == 0 or i + 1 == nBatch:
batch_log = 'Warmup Train [{0}][{1}/{2}]\t' \
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
losses=losses, top1=top1, top5=top5, lr=warmup_lr)
logger.info(batch_log)
val_loss, val_top1, val_top5 = self._validate()
val_log = 'Warmup Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}\t' \
'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}M'. \
format(epoch + 1, self.warmup_epochs, val_loss, val_top1, val_top5, top1=top1, top5=top5)
logger.info(val_log)
self.save_checkpoint()
self.warmup_curr_epoch += 1
def _get_update_schedule(self, nBatch):
"""
Generate schedule for training architecture weights. Key means after which minibatch
to update architecture weights, value means how many steps for the update.
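For example, a sketch with ``grad_update_arch_param_every=5``, ``grad_update_steps=1`` and ``nBatch=12``:
.. code-block:: python
    schedule = self._get_update_schedule(12)
    # schedule == {4: 1, 9: 1}: update after minibatches 5 and 10, one step each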
Parameters
----------
nBatch : int
the total number of minibatches in one epoch
Returns
-------
dict
the schedule for updating architecture weights
"""
schedule = {}
for i in range(nBatch):
if (i + 1) % self.grad_update_arch_param_every == 0:
schedule[i] = self.grad_update_steps
return schedule
def _calc_learning_rate(self, epoch, batch=0, nBatch=None):
"""
Calculate the learning rate according to a cosine annealing schedule.
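A restatement of the rule computed below, where ``T_total = n_epochs * nBatch`` and ``T_cur = epoch * nBatch + batch``:
.. math::
    lr = \frac{1}{2}\, lr_{init} \left(1 + \cos\left(\pi \frac{T_{cur}}{T_{total}}\right)\right)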
"""
T_total = self.n_epochs * nBatch
T_cur = epoch * nBatch + batch
lr = 0.5 * self.init_lr * (1 + math.cos(math.pi * T_cur / T_total))
return lr
def _adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
"""
Adjust the learning rate of a given optimizer and return the new learning rate.
Parameters
----------
optimizer : pytorch optimizer
the used optimizer
epoch : int
the current epoch number
batch : int
the current minibatch
nBatch : int
the total number of minibatches in one epoch
Returns
-------
float
the adjusted learning rate
"""
new_lr = self._calc_learning_rate(epoch, batch, nBatch)
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
return new_lr
def _train(self):
"""
Train the model. It trains model weights and architecture weights.
Architecture weights are trained according to the schedule.
Before updating architecture weights, ```requires_grad``` is enabled.
Then, it is disabled after the updating, in order not to update
architecture weights when training model weights.
"""
nBatch = len(self.train_loader)
arch_param_num = self.mutator.num_arch_params()
binary_gates_num = self.mutator.num_arch_params()
logger.info('#arch_params: %d\t#binary_gates: %d', arch_param_num, binary_gates_num)
update_schedule = self._get_update_schedule(nBatch)
for epoch in range(self.train_curr_epoch, self.n_epochs):
logger.info('\n--------Train epoch: %d--------\n', epoch + 1)
batch_time = AverageMeter('batch_time')
data_time = AverageMeter('data_time')
losses = AverageMeter('losses')
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
# switch to train mode
self.model.train()
end = time.time()
for i, (images, labels) in enumerate(self.train_loader):
data_time.update(time.time() - end)
lr = self._adjust_learning_rate(self.model_optim, epoch, batch=i, nBatch=nBatch)
# train weight parameters
images, labels = images.to(self.device), labels.to(self.device)
self.mutator.reset_binary_gates()
self.mutator.unused_modules_off()
output = self.model(images)
if self.label_smoothing > 0:
loss = cross_entropy_with_label_smoothing(output, labels, self.label_smoothing)
else:
loss = self.criterion(output, labels)
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
losses.update(loss, images.size(0))
top1.update(acc1[0], images.size(0))
top5.update(acc5[0], images.size(0))
self.model.zero_grad()
loss.backward()
self.model_optim.step()
self.mutator.unused_modules_back()
if epoch > 0:
for _ in range(update_schedule.get(i, 0)):
start_time = time.time()
# GradientArchSearchConfig
self.mutator.arch_requires_grad()
arch_loss, exp_value = self._gradient_step()
self.mutator.arch_disable_grad()
used_time = time.time() - start_time
log_str = 'Architecture [%d-%d]\t Time %.4f\t Loss %.4f\t null %s' % \
(epoch + 1, i, used_time, arch_loss, exp_value)
logger.info(log_str)
batch_time.update(time.time() - end)
end = time.time()
# training log
if i % 10 == 0 or i + 1 == nBatch:
batch_log = 'Train [{0}][{1}/{2}]\t' \
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t' \
'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
losses=losses, top1=top1, top5=top5, lr=lr)
logger.info(batch_log)
# validate
if (epoch + 1) % self.arch_valid_frequency == 0:
val_loss, val_top1, val_top5 = self._validate()
val_log = 'Valid [{0}]\tloss {1:.3f}\ttop-1 acc {2:.3f} \ttop-5 acc {3:.3f}\t' \
'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
format(epoch + 1, val_loss, val_top1, val_top5, top1=top1, top5=top5)
logger.info(val_log)
self.save_checkpoint()
self.train_curr_epoch += 1
def _valid_next_batch(self):
"""
Get the next minibatch from the validation set.
Returns
-------
(tensor, tensor)
the tuple of images and labels
"""
if self._valid_iter is None:
self._valid_iter = iter(self.valid_loader)
try:
data = next(self._valid_iter)
except StopIteration:
self._valid_iter = iter(self.valid_loader)
data = next(self._valid_iter)
return data
def _gradient_step(self):
"""
This gradient step is for updating architecture weights.
Mutator is intensively used in this function to operate on
architecture weights.
Returns
-------
float, None
loss of the model, None
"""
# use the same batch size as train batch size for architecture weights
self.valid_loader.batch_sampler.batch_size = self.train_batch_size
self.valid_loader.batch_sampler.drop_last = True
self.model.train()
self.mutator.change_forward_mode(self.binary_mode)
time1 = time.time() # time
# sample a batch of data from validation set
images, labels = self._valid_next_batch()
images, labels = images.to(self.device), labels.to(self.device)
time2 = time.time() # time
self.mutator.reset_binary_gates()
self.mutator.unused_modules_off()
output = self.model(images)
time3 = time.time()
ce_loss = self.criterion(output, labels)
expected_value = None
loss = ce_loss
self.model.zero_grad()
loss.backward()
self.mutator.set_arch_param_grad()
self.arch_optimizer.step()
if self.mutator.get_forward_mode() == 'two':
self.mutator.rescale_updated_arch_param()
self.mutator.unused_modules_back()
self.mutator.change_forward_mode(None)
time4 = time.time()
logger.info('(%.4f, %.4f, %.4f)', time2 - time1, time3 - time2, time4 - time3)
return loss.data.item(), expected_value.item() if expected_value is not None else None
def save_checkpoint(self):
"""
Save checkpoint of the whole model. Model weights and architecture weights are saved to
```ckpt_path```, and the currently chosen architecture is saved to ```arch_path```.
"""
if self.ckpt_path:
state = {
'warmup_curr_epoch': self.warmup_curr_epoch,
'train_curr_epoch': self.train_curr_epoch,
'model': self.model.state_dict(),
'optim': self.model_optim.state_dict(),
'arch_optim': self.arch_optimizer.state_dict()
}
torch.save(state, self.ckpt_path)
if self.arch_path:
self.export(self.arch_path)
def load_checkpoint(self):
"""
Load the checkpoint from ```ckpt_path```.
"""
assert self.ckpt_path is not None, "If load_ckpt is True, ckpt_path should not be None"
ckpt = torch.load(self.ckpt_path)
self.warmup_curr_epoch = ckpt['warmup_curr_epoch']
self.train_curr_epoch = ckpt['train_curr_epoch']
self.model.load_state_dict(ckpt['model'])
self.model_optim.load_state_dict(ckpt['optim'])
self.arch_optimizer.load_state_dict(ckpt['arch_optim'])
def train(self):
"""
Train the whole model.
"""
if self.load_ckpt:
self.load_checkpoint()
if self.warmup:
self._warm_up()
self._train()
def export(self, file_name):
"""
Export the chosen architecture into a file
Parameters
----------
file_name : str
the file that stores exported chosen architecture
"""
exported_arch = self.mutator.sample_final()
with open(file_name, 'w') as f:
json.dump(exported_arch, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
def validate(self):
raise NotImplementedError
def checkpoint(self):
raise NotImplementedError
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
def detach_variable(inputs):
"""
Detach variables
Parameters
----------
inputs : pytorch tensors
pytorch tensors
"""
if isinstance(inputs, tuple):
return tuple([detach_variable(x) for x in inputs])
else:
x = inputs.detach()
x.requires_grad = inputs.requires_grad
return x
def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
"""
Parameters
----------
pred : pytorch tensor
predicted value
target : pytorch tensor
label
label_smoothing : float
the degree of label smoothing
Returns
-------
pytorch tensor
cross entropy
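A restatement of what the code below computes, with smoothing factor ``label_smoothing`` written as :math:`\epsilon` and :math:`C` classes:
.. math::
    \tilde{y} = (1 - \epsilon)\, y_{onehot} + \frac{\epsilon}{C}, \qquad
    L = \operatorname{mean}_n \sum_{c} -\tilde{y}_{nc} \log \operatorname{softmax}(pred)_{nc}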
"""
logsoftmax = nn.LogSoftmax()
n_classes = pred.size(1)
# convert to one-hot
target = torch.unsqueeze(target, 1)
soft_target = torch.zeros_like(pred)
soft_target.scatter_(1, target, 1)
# label smoothing
soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
def accuracy(output, target, topk=(1,)):
"""
Computes the precision@k for the specified values of k
Parameters
----------
output : pytorch tensor
output, e.g., predicted value
target : pytorch tensor
label
topk : tuple
specify top1 and top5
Returns
-------
list
accuracy of top1 and top5
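A usage sketch, as in the trainer above:
.. code-block:: python
    acc1, acc5 = accuracy(output, labels, topk=(1, 5))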
"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
...@@ -6,7 +6,15 @@ from nni.nas.pytorch.mutables import LayerChoice, InputChoice
class RandomMutator(Mutator):
"""
Random mutator that samples a random candidate in the search space each time ``reset()`` is called.
It uses PyTorch's random functions, so users can set the seed in PyTorch to ensure deterministic behavior.
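A minimal usage sketch (assuming ``model`` contains ``LayerChoice``/``InputChoice`` mutables and ``x`` is an input batch):
.. code-block:: python
    mutator = RandomMutator(model)
    mutator.reset()          # sample a new random architecture
    output = model(x)        # forward pass uses the sampled decisions
    arch = mutator.export()  # dict mapping mutable keys to decisions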
"""
def sample_search(self):
"""
Sample a random candidate.
"""
result = dict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
...@@ -22,4 +30,7 @@ class RandomMutator(Mutator):
return result
def sample_final(self):
"""
Same as :meth:`sample_search`.
"""
return self.sample_search()
...@@ -16,29 +16,29 @@ _logger = logging.getLogger(__name__)
class SPOSEvolution(Tuner):
"""
SPOS evolution tuner.
Parameters
----------
max_epochs : int
Maximum number of epochs to run.
num_select : int
Number of survival candidates of each epoch.
num_population : int
Number of candidates at the start of each epoch. If candidates generated by
crossover and mutation are not enough, the rest will be filled with random
candidates.
m_prob : float
The probability of mutation.
num_crossover : int
Number of candidates generated by crossover in each epoch.
num_mutation : int
Number of candidates generated by mutation in each epoch.
"""
def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
num_crossover=25, num_mutation=25):
"""
Initialize SPOS Evolution Tuner.
Parameters
----------
max_epochs : int
Maximum number of epochs to run.
num_select : int
Number of survival candidates of each epoch.
num_population : int
Number of candidates at the start of each epoch. If candidates generated by
crossover and mutation are not enough, the rest will be filled with random
candidates.
m_prob : float
The probability of mutation.
num_crossover : int
Number of candidates generated by crossover in each epoch.
num_mutation : int
Number of candidates generated by mutation in each epoch.
"""
assert num_population >= num_select
self.max_epochs = max_epochs
self.num_select = num_select
......
...@@ -10,27 +10,29 @@ _logger = logging.getLogger(__name__)
class SPOSSupernetTrainingMutator(RandomMutator):
"""
A random mutator with flops limit.
Parameters
----------
model : nn.Module
PyTorch model.
flops_func : callable
Callable that takes a candidate from `sample_search` and returns its flops. When `flops_func`
is None, functions related to flops will be deactivated.
flops_lb : number
Lower bound of flops.
flops_ub : number
Upper bound of flops.
flops_bin_num : number
Number of bins divided for the interval of flops to ensure the uniformity. Bigger number will be more
uniform, but the sampling will be slower.
flops_sample_timeout : int
Maximum number of attempts to sample before giving up and using a random candidate.
"""
def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None,
flops_bin_num=7, flops_sample_timeout=500):
"""
Parameters
----------
model : nn.Module
PyTorch model.
flops_func : callable
Callable that takes a candidate from `sample_search` and returns its candidate. When `flops_func`
is None, functions related to flops will be deactivated.
flops_lb : number
Lower bound of flops.
flops_ub : number
Upper bound of flops.
flops_bin_num : number
Number of bins divided for the interval of flops to ensure the uniformity. Bigger number will be more
uniform, but the sampling will be slower.
flops_sample_timeout : int
Maximum number of attempts to sample before giving up and use a random candidate.
"""
super().__init__(model)
self._flops_func = flops_func
if self._flops_func is not None:
......
...@@ -15,43 +15,42 @@ logger = logging.getLogger(__name__)
class SPOSSupernetTrainer(Trainer):
"""
This trainer trains a supernet that can be used for evolution search.
Parameters
----------
model : nn.Module
Model with mutables.
mutator : Mutator
A mutator object that has been initialized with the model.
loss : callable
Called with logits and targets. Returns a loss tensor.
metrics : callable
Returns a dict that maps metrics keys to metrics data.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
train_loader : iterable
Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
valid_loader : iterable
Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
batch_size : int
Batch size.
workers : int
Number of threads for data preprocessing. Not used for this trainer. Maybe removed in future.
device : torch.device
Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
automatically detect GPU and select GPU first.
log_frequency : int
Number of mini-batches to log metrics.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
""" """
def __init__(self, model, loss, metrics,
optimizer, num_epochs, train_loader, valid_loader,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None):
"""
Parameters
----------
model : nn.Module
Model with mutables.
mutator : Mutator
A mutator object that has been initialized with the model.
loss : callable
Called with logits and targets. Returns a loss tensor.
metrics : callable
Returns a dict that maps metrics keys to metrics data.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
train_loader : iterable
Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
dataset_valid : iterable
Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
batch_size : int
Batch size.
workers: int
Number of threads for data preprocessing. Not used for this trainer. Maybe removed in future.
device : torch.device
Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
automatic detects GPU and selects GPU first.
log_frequency : int
Number of mini-batches to log metrics.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
"""
assert torch.cuda.is_available()
super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
loss, metrics, optimizer, num_epochs, None, None,
......
...@@ -24,42 +24,54 @@ class TorchTensorEncoder(json.JSONEncoder):
class Trainer(BaseTrainer):
"""
A trainer with some helper functions implemented. To implement a new trainer,
users need to implement :meth:`train_one_epoch`, :meth:`validate_one_epoch` and :meth:`checkpoint`.
Parameters
----------
model : nn.Module
Model with mutables.
mutator : BaseMutator
A mutator object that has been initialized with the model.
loss : callable
Called with logits and targets. Returns a loss tensor.
See `PyTorch loss functions`_ for examples.
metrics : callable
Called with logits and targets. Returns a dict that maps metrics keys to metrics data. For example,
.. code-block:: python
def metrics_fn(output, target):
return {"acc1": accuracy(output, target, topk=1), "acc5": accuracy(output, target, topk=5)}
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
dataset_train : torch.utils.data.Dataset
Dataset of training. If not otherwise specified, ``dataset_train`` and ``dataset_valid`` should be standard
PyTorch Dataset. See `torch.utils.data`_ for examples.
dataset_valid : torch.utils.data.Dataset
Dataset of validation/testing.
batch_size : int
Batch size.
workers : int
Number of workers used in data preprocessing.
device : torch.device
Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
automatically detect GPU and select GPU first.
log_frequency : int
Number of mini-batches to log metrics.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
.. _`PyTorch loss functions`: https://pytorch.org/docs/stable/nn.html#loss-functions
.. _`torch.utils.data`: https://pytorch.org/docs/stable/data.html
"""
def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
"""
Trainer initialization.
Parameters
----------
model : nn.Module
Model with mutables.
mutator : BaseMutator
A mutator object that has been initialized with the model.
loss : callable
Called with logits and targets. Returns a loss tensor.
metrics : callable
Returns a dict that maps metrics keys to metrics data.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
dataset_train : torch.utils.data.Dataset
Dataset of training.
dataset_valid : torch.utils.data.Dataset
Dataset of validation/testing.
batch_size : int
Batch size.
workers : int
Number of workers used in data preprocessing.
device : torch.device
Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
automatic detects GPU and selects GPU first.
log_frequency : int
Number of mini-batches to log metrics.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.model = model
self.mutator = mutator
...@@ -84,13 +96,38 @@ class Trainer(BaseTrainer):
@abstractmethod
def train_one_epoch(self, epoch):
"""
Train one epoch.
Parameters
----------
epoch : int
Epoch number starting from 0.
"""
pass
@abstractmethod
def validate_one_epoch(self, epoch):
"""
Validate one epoch.
Parameters
----------
epoch : int
Epoch number starting from 0.
"""
pass
def train(self, validate=True):
"""
Train ``num_epochs``.
Trigger callbacks at the start and the end of each epoch.
Parameters
----------
validate : bool
If ``True``, do validation every epoch.
"""
for epoch in range(self.num_epochs):
for callback in self.callbacks:
callback.on_epoch_begin(epoch)
...@@ -108,12 +145,26 @@ class Trainer(BaseTrainer):
callback.on_epoch_end(epoch)
def validate(self):
"""
Do one validation.
"""
self.validate_one_epoch(-1)
def export(self, file):
"""
Call ``mutator.export()`` and dump the architecture to ``file``.
Parameters
----------
file : str
A file path. Expected to be a JSON.
"""
mutator_export = self.mutator.export()
with open(file, "w") as f:
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
def checkpoint(self):
"""
Return trainer checkpoint.
"""
raise NotImplementedError("Not implemented yet") raise NotImplementedError("Not implemented yet")
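# --- Illustrative sketch (not part of the diff above) ---
# A minimal example of driving a concrete trainer built on this base class,
# assuming only the contract shown in this file: subclasses implement the two
# abstract hooks, then call train() / validate() / export(). The subclass name,
# the commented-out arguments, and the import path are assumptions, not NNI's
# built-in trainers.
from nni.nas.pytorch.trainer import Trainer  # path assumed from this file's location

class SketchTrainer(Trainer):
    def train_one_epoch(self, epoch):
        # iterate over the training data, sample an architecture with self.mutator,
        # compute the loss and step the optimizer
        pass

    def validate_one_epoch(self, epoch):
        # iterate over the validation data and log metric values
        pass

# trainer = SketchTrainer(model, mutator, loss, metrics, optimizer, num_epochs=2,
#                         dataset_train=train_set, dataset_valid=valid_set,
#                         batch_size=64, workers=4, device=None,
#                         log_frequency=10, callbacks=[])
# trainer.train(validate=True)          # callbacks fire at every epoch boundary
# trainer.export("architecture.json")   # dumps mutator.export() as JSON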
...@@ -12,12 +12,26 @@ _logger = logging.getLogger(__name__) ...@@ -12,12 +12,26 @@ _logger = logging.getLogger(__name__)
def global_mutable_counting(): def global_mutable_counting():
"""
A program-level counter starting from 1.
"""
global _counter global _counter
_counter += 1 _counter += 1
return _counter return _counter
def _reset_global_mutable_counting():
"""
Reset the global mutable counting to count from 1. Useful when defining multiple models with default keys.
"""
global _counter
_counter = 0
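# --- Illustrative sketch (not part of the diff above) ---
# The counter is shared by the whole program, so two models defined back to back
# keep incrementing it unless it is reset in between. Standalone example; the
# import path is assumed from this file's context.
from nni.nas.pytorch.utils import global_mutable_counting, _reset_global_mutable_counting

assert global_mutable_counting() == 1   # first auto-generated id (fresh process)
assert global_mutable_counting() == 2   # keeps counting across model definitions
_reset_global_mutable_counting()
assert global_mutable_counting() == 1   # counting starts from 1 again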
def to_device(obj, device): def to_device(obj, device):
"""
Move a tensor, tuple, list, or dict to the given device.
"""
if torch.is_tensor(obj): if torch.is_tensor(obj):
return obj.to(device) return obj.to(device)
if isinstance(obj, tuple): if isinstance(obj, tuple):
...@@ -32,12 +46,18 @@ def to_device(obj, device): ...@@ -32,12 +46,18 @@ def to_device(obj, device):
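# --- Illustrative sketch (not part of the diff above) ---
# to_device moves nested containers element by element; a dict of tensors is a
# typical batch. Standalone example; the import path is assumed.
import torch
from nni.nas.pytorch.utils import to_device

batch = {"inputs": torch.zeros(4, 3), "labels": torch.tensor([0, 1, 0, 1])}
moved = to_device(batch, torch.device("cpu"))
assert moved["inputs"].device.type == "cpu"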
class AverageMeterGroup: class AverageMeterGroup:
"""Average meter group for multiple average meters""" """
Average meter group for multiple average meters.
"""
def __init__(self): def __init__(self):
self.meters = OrderedDict() self.meters = OrderedDict()
def update(self, data): def update(self, data):
"""
Update the meter group with a dict of metrics.
Average meters that do not exist yet will be created automatically.
"""
for k, v in data.items(): for k, v in data.items():
if k not in self.meters: if k not in self.meters:
self.meters[k] = AverageMeter(k, ":4f") self.meters[k] = AverageMeter(k, ":4f")
...@@ -53,36 +73,49 @@ class AverageMeterGroup: ...@@ -53,36 +73,49 @@ class AverageMeterGroup:
return " ".join(str(v) for v in self.meters.values()) return " ".join(str(v) for v in self.meters.values())
def summary(self): def summary(self):
"""
Return a summary string of group data.
"""
return " ".join(v.summary() for v in self.meters.values()) return " ".join(v.summary() for v in self.meters.values())
class AverageMeter: class AverageMeter:
"""Computes and stores the average and current value""" """
Computes and stores the average and current value.
Parameters
----------
name : str
Name to display.
fmt : str
Format string to print the values.
"""
def __init__(self, name, fmt=':f'): def __init__(self, name, fmt=':f'):
"""
Initialize an AverageMeter.
Parameters
----------
name : str
Name to display.
fmt : str
Format string to print the values.
"""
self.name = name self.name = name
self.fmt = fmt self.fmt = fmt
self.reset() self.reset()
def reset(self): def reset(self):
"""
Reset the meter.
"""
self.val = 0 self.val = 0
self.avg = 0 self.avg = 0
self.sum = 0 self.sum = 0
self.count = 0 self.count = 0
def update(self, val, n=1): def update(self, val, n=1):
if not isinstance(val, float) and not isinstance(val, int): """
_logger.warning("Values passed to AverageMeter must be number, not %s.", type(val)) Update with value and weight.
Parameters
----------
val : float or int
The new value to be accounted for.
n : int
The weight of the new value.
"""
self.val = val self.val = val
self.sum += val * n self.sum += val * n
self.count += n self.count += n
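# --- Illustrative sketch (not part of the diff above) ---
# Worked example of the weighted running average, assuming ``avg`` is maintained
# as sum / count (consistent with the fields reset above). Standalone example;
# the import path is assumed.
from nni.nas.pytorch.utils import AverageMeter

meter = AverageMeter("loss", ":.4f")
meter.update(1.0)        # sum = 1.0, count = 1, avg = 1.0
meter.update(0.5, n=3)   # sum = 2.5, count = 4, avg = 0.625
assert abs(meter.avg - 0.625) < 1e-9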
...@@ -104,6 +137,11 @@ class StructuredMutableTreeNode: ...@@ -104,6 +137,11 @@ class StructuredMutableTreeNode:
This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet, This tree can be seen as a "flattened" version of the module tree. Since nested mutable entity is not supported yet,
the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a
``Mutable`` (other than ``MutableScope``). ``Mutable`` (other than ``MutableScope``).
Parameters
----------
mutable : nni.nas.pytorch.mutables.Mutable
The mutable that the current node is linked with.
""" """
def __init__(self, mutable): def __init__(self, mutable):
...@@ -111,10 +149,16 @@ class StructuredMutableTreeNode: ...@@ -111,10 +149,16 @@ class StructuredMutableTreeNode:
self.children = [] self.children = []
def add_child(self, mutable): def add_child(self, mutable):
"""
Add a tree node to the children list of the current node and return the new node.
"""
self.children.append(StructuredMutableTreeNode(mutable)) self.children.append(StructuredMutableTreeNode(mutable))
return self.children[-1] return self.children[-1]
def type(self): def type(self):
"""
Return the ``type`` of the mutable held by this node.
"""
return type(self.mutable) return type(self.mutable)
def __iter__(self): def __iter__(self):
......
...@@ -33,4 +33,7 @@ def init_params(params): ...@@ -33,4 +33,7 @@ def init_params(params):
_params = copy.deepcopy(params) _params = copy.deepcopy(params)
def get_last_metric(): def get_last_metric():
return json_tricks.loads(_last_metric) metrics = json_tricks.loads(_last_metric)
metrics['value'] = json_tricks.loads(metrics['value'])
return metrics
...@@ -114,7 +114,7 @@ def report_intermediate_result(metric): ...@@ -114,7 +114,7 @@ def report_intermediate_result(metric):
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'PERIODICAL', 'type': 'PERIODICAL',
'sequence': _intermediate_seq, 'sequence': _intermediate_seq,
'value': metric 'value': to_json(metric)
}) })
_intermediate_seq += 1 _intermediate_seq += 1
platform.send_metric(metric) platform.send_metric(metric)
...@@ -135,6 +135,6 @@ def report_final_result(metric): ...@@ -135,6 +135,6 @@ def report_final_result(metric):
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID, 'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'FINAL', 'type': 'FINAL',
'sequence': 0, 'sequence': 0,
'value': metric 'value': to_json(metric)
}) })
platform.send_metric(metric) platform.send_metric(metric)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutable_scope import SpaceWithMutableScope
from .naive import NaiveSearchSpace
from .nested import NestedSpace
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
class Cell(MutableScope):
def __init__(self, cell_name, prev_labels, channels):
super().__init__(cell_name)
self.input_choice = InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
key=cell_name + "_input")
self.op_choice = LayerChoice([
nn.Conv2d(channels, channels, 3, padding=1),
nn.Conv2d(channels, channels, 5, padding=2),
nn.MaxPool2d(3, stride=1, padding=1),
nn.AvgPool2d(3, stride=1, padding=1),
nn.Identity()
], key=cell_name + "_op")
def forward(self, prev_layers):
chosen_input, chosen_mask = self.input_choice(prev_layers)
cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask
class Node(MutableScope):
def __init__(self, node_name, prev_node_names, channels):
super().__init__(node_name)
self.cell_x = Cell(node_name + "_x", prev_node_names, channels)
self.cell_y = Cell(node_name + "_y", prev_node_names, channels)
def forward(self, prev_layers):
out_x, mask_x = self.cell_x(prev_layers)
out_y, mask_y = self.cell_y(prev_layers)
return out_x + out_y, mask_x | mask_y
class Layer(nn.Module):
def __init__(self, num_nodes, channels):
super().__init__()
self.num_nodes = num_nodes
self.nodes = nn.ModuleList()
node_labels = [InputChoice.NO_KEY, InputChoice.NO_KEY]
for i in range(num_nodes):
node_labels.append("node_{}".format(i))
self.nodes.append(Node(node_labels[-1], node_labels[:-1], channels))
self.final_conv_w = nn.Parameter(torch.zeros(channels, self.num_nodes + 2, channels, 1, 1),
requires_grad=True)
self.bn = nn.BatchNorm2d(channels, affine=False)
def forward(self, pprev, prev):
prev_nodes_out = [pprev, prev]
nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
for i in range(self.num_nodes):
node_out, mask = self.nodes[i](prev_nodes_out)
nodes_used_mask[:mask.size(0)] |= mask.to(prev.device)
# NOTE: which device should we put mask on?
prev_nodes_out.append(node_out)
unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
unused_nodes = F.relu(unused_nodes)
conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :]
conv_weight = conv_weight.view(conv_weight.size(0), -1, 1, 1)
out = F.conv2d(unused_nodes, conv_weight)
return prev, self.bn(out)
class SpaceWithMutableScope(nn.Module):
def __init__(self, test_case, num_layers=4, num_nodes=5, channels=16, in_channels=3, num_classes=10):
super().__init__()
self.test_case = test_case
self.num_layers = num_layers
self.stem = nn.Sequential(
nn.Conv2d(in_channels, channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(channels)
)
self.layers = nn.ModuleList()
for _ in range(self.num_layers + 2):
self.layers.append(Layer(num_nodes, channels))
self.gap = nn.AdaptiveAvgPool2d(1)
self.dense = nn.Linear(channels, num_classes)
def forward(self, x):
prev = cur = self.stem(x)
for layer in self.layers:
prev, cur = layer(prev, cur)
cur = self.gap(F.relu(cur)).view(x.size(0), -1)
return self.dense(cur)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
class NaiveSearchSpace(nn.Module):
def __init__(self, test_case):
super().__init__()
self.test_case = test_case
self.conv1 = LayerChoice([nn.Conv2d(3, 6, 3, padding=1), nn.Conv2d(3, 6, 5, padding=2)])
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = LayerChoice([nn.Conv2d(6, 16, 3, padding=1), nn.Conv2d(6, 16, 5, padding=2)],
return_mask=True)
self.conv3 = nn.Conv2d(16, 16, 1)
self.skipconnect = InputChoice(n_candidates=1)
self.skipconnect2 = InputChoice(n_candidates=2, return_mask=True)
self.bn = nn.BatchNorm2d(16)
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 10)
def forward(self, x):
bs = x.size(0)
x = self.pool(F.relu(self.conv1(x)))
x0, mask = self.conv2(x)
self.test_case.assertEqual(mask.size(), torch.Size([2]))
x1 = F.relu(self.conv3(x0))
_, mask = self.skipconnect2([x0, x1])
x0 = self.skipconnect([x0])
if x0 is not None:
x1 += x0
x = self.pool(self.bn(x1))
self.test_case.assertEqual(mask.size(), torch.Size([2]))
x = self.gap(x).view(bs, -1)
x = self.fc(x)
return x
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
class MutableOp(nn.Module):
def __init__(self, kernel_size):
super().__init__()
self.conv = nn.Conv2d(3, 120, kernel_size, padding=kernel_size // 2)
self.nested_mutable = InputChoice(n_candidates=10)
def forward(self, x):
return self.conv(x)
class NestedSpace(nn.Module):
# this doesn't pass tests
def __init__(self, test_case):
super().__init__()
self.test_case = test_case
self.conv1 = LayerChoice([MutableOp(3), MutableOp(5)])
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Linear(120, 10)
def forward(self, x):
bs = x.size(0)
x = F.relu(self.conv1(x))
x = self.gap(x).view(bs, -1)
x = self.fc1(x)
return x
...@@ -47,9 +47,9 @@ def _restore_io(): ...@@ -47,9 +47,9 @@ def _restore_io():
class AssessorTestCase(TestCase): class AssessorTestCase(TestCase):
def test_assessor(self): def test_assessor(self):
_reverse_io() _reverse_io()
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":2}') send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":2}') send(CommandType.ReportMetricData, '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":3}') send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}')
send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED"}') send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED"}')
send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED"}') send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED"}')
send(CommandType.NewTrialJob, 'null') send(CommandType.NewTrialJob, 'null')
......
...@@ -135,12 +135,11 @@ class CompressorTestCase(TestCase): ...@@ -135,12 +135,11 @@ class CompressorTestCase(TestCase):
model.conv2.weight.data = torch.tensor(w).float() model.conv2.weight.data = torch.tensor(w).float()
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2) layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0]) masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.])) assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1)
model.conv2.weight.data = torch.tensor(w).float() model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1]) masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.])) assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@tf2 @tf2
...@@ -159,7 +158,6 @@ class CompressorTestCase(TestCase): ...@@ -159,7 +158,6 @@ class CompressorTestCase(TestCase):
assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.])) assert all(masks.sum((1)) == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1)
model.layers[2].set_weights([weights[0], weights[1].numpy()]) model.layers[2].set_weights([weights[0], weights[1].numpy()])
masks = pruner.calc_mask(layer, config_list[1]).numpy() masks = pruner.calc_mask(layer, config_list[1]).numpy()
masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0]) masks = masks.reshape((-1, masks.shape[-1])).transpose([1, 0])
...@@ -187,9 +185,9 @@ class CompressorTestCase(TestCase): ...@@ -187,9 +185,9 @@ class CompressorTestCase(TestCase):
model.conv1.weight.data = torch.tensor(w).float() model.conv1.weight.data = torch.tensor(w).float()
model.conv2.weight.data = torch.tensor(w).float() model.conv2.weight.data = torch.tensor(w).float()
layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1) layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
mask1 = pruner.calc_mask(layer1, config_list[0]) mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2) layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
mask2 = pruner.calc_mask(layer2, config_list[1]) mask2 = pruner.calc_mask(layer2, config_list[1], if_calculated=torch.tensor(0))
assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.])) assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.])) assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
...@@ -215,9 +213,9 @@ class CompressorTestCase(TestCase): ...@@ -215,9 +213,9 @@ class CompressorTestCase(TestCase):
pruner = torch_compressor.SlimPruner(model, config_list) pruner = torch_compressor.SlimPruner(model, config_list)
layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1) layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0]) mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2) layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0]) mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.])) assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.])) assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.])) assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
...@@ -229,9 +227,9 @@ class CompressorTestCase(TestCase): ...@@ -229,9 +227,9 @@ class CompressorTestCase(TestCase):
pruner = torch_compressor.SlimPruner(model, config_list) pruner = torch_compressor.SlimPruner(model, config_list)
layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1) layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0]) mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2) layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0]) mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
...@@ -268,14 +266,14 @@ class CompressorTestCase(TestCase): ...@@ -268,14 +266,14 @@ class CompressorTestCase(TestCase):
# test ema # test ema
x = torch.tensor([[-0.2, 0], [0.1, 0.2]]) x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
out = model.relu(x) out = model.relu(x)
assert math.isclose(model.relu.tracked_min_biased, 0, abs_tol=eps) assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
assert math.isclose(model.relu.tracked_max_biased, 0.002, abs_tol=eps) assert math.isclose(model.relu.module.tracked_max_biased, 0.002, abs_tol=eps)
quantizer.step() quantizer.step()
x = torch.tensor([[0.2, 0.4], [0.6, 0.8]]) x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
out = model.relu(x) out = model.relu(x)
assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps) assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps) assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)
if __name__ == '__main__': if __name__ == '__main__':
......