Unverified Commit 219d7d19 authored by Jiahang Xu, committed by GitHub

Feature: support nn-Meter in Proxyless NAS (#4206)

parent 1b01a7e3
@@ -13,14 +13,17 @@ To use ProxylessNAS training/searching approach, users need to specify search sp
.. code-block:: python

   trainer = ProxylessTrainer(model,
                              loss=LabelSmoothingLoss(),
                              dataset=None,
                              optimizer=optimizer,
                              metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
                              num_epochs=120,
                              log_frequency=10,
                              grad_reg_loss_type=args.grad_reg_loss_type,
                              grad_reg_loss_params=grad_reg_loss_params,
                              applied_hardware=args.applied_hardware, dummy_input=(1, 3, 224, 224),
                              ref_latency=args.reference_latency)
   trainer.train()
   trainer.export(args.arch_path)
@@ -30,44 +33,42 @@ The complete example code can be found :githublink:`here <examples/nas/oneshot/p
* **model** (*PyTorch model, required*\ ) - The model that users want to tune/search. It has mutables to specify the search space.
* **metrics** (*PyTorch module, required*\ ) - The main term of the loss function for model training. Receives logits and ground truth labels, returns a loss tensor.
* **optimizer** (*PyTorch Optimizer, required*\ ) - The optimizer used for optimizing the model.
* **num_epochs** (*int, optional, default = 120*\ ) - The number of epochs to train/search.
* **dataset** (*PyTorch dataset, required*\ ) - Dataset for training. It will be split into one part for training model weights and another for training architecture weights.
* **warmup_epochs** (*int, optional, default = 0*\ ) - The number of warmup epochs.
* **batch_size** (*int, optional, default = 64*\ ) - Batch size.
* **workers** (*int, optional, default = 4*\ ) - Number of workers for data loading.
* **device** (*device, optional, default = 'cpu'*\ ) - The device on which to run the training/search. The trainer applies data parallel on the model for users.
* **log_frequency** (*int, optional, default = None*\ ) - Step count per logging.
* **arc_learning_rate** (*float, optional, default = 1e-3*\ ) - The learning rate of the architecture parameters optimizer.
* **grad_reg_loss_type** (*'mul#log', 'add#linear', or None, optional, default = 'add#linear'*\ ) - Regularization type for the hardware-related loss term. The trainer applies no loss regularization when ``grad_reg_loss_type`` is set to None. See the example after this list.
* **grad_reg_loss_params** (*dict, optional, default = None*\ ) - Regularization parameters. 'alpha' and 'beta' are required when ``grad_reg_loss_type`` is 'mul#log'; 'lambda' is required when ``grad_reg_loss_type`` is 'add#linear'.
* **applied_hardware** (*string, optional, default = None*\ ) - The hardware used to constrain the model's latency. Latency is predicted by Microsoft nn-Meter (https://github.com/microsoft/nn-Meter).
* **dummy_input** (*tuple, optional, default = (1, 3, 224, 224)*\ ) - The dummy input shape used when profiling the model on the target hardware.
* **ref_latency** (*float, optional, default = 65.0*\ ) - Reference latency value on the applied hardware (ms).
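For instance, under the default ``add#linear`` regularization the hardware term is a linear penalty on the gap between the expected and the reference latency. A minimal sketch with illustrative numbers (not taken from the example code):

.. code-block:: python

   expected_latency = 80.0  # ms, predicted by nn-Meter for the current architecture distribution
   ref_latency = 65.0       # ms, the target latency on the applied hardware
   reg_lambda = 0.1         # grad_reg_loss_params['lambda']

   reg_loss = reg_lambda * (expected_latency - ref_latency) / ref_latency  # ~0.023
   total_loss = ce_loss + reg_loss  # ce_loss: the usual cross-entropy term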
Implementation
--------------

The implementation on NNI is based on the `official implementation <https://github.com/mit-han-lab/ProxylessNAS>`__. The official implementation supports two training approaches: gradient descent and RL-based. Our current implementation on NNI supports the gradient descent training approach; complete support of ProxylessNAS is ongoing.

The official implementation supports several target hardware platforms, including 'mobile', 'cpu', 'gpu8', and 'flops'. In the NNI repo, hardware latency prediction is provided by `Microsoft nn-Meter <https://github.com/microsoft/nn-Meter>`__. nn-Meter is an accurate inference latency predictor for DNN models on diverse edge devices. nn-Meter currently supports four hardware platforms: *'cortexA76cpu_tflite21'*, *'adreno640gpu_tflite21'*, *'adreno630gpu_tflite21'*, and *'myriadvpu_openvino2019r2'*. Users can find more information about nn-Meter on its website; more hardware will be supported in the future.
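For example, one of these predictors can be loaded by name and queried directly. A minimal sketch, assuming nn-Meter is installed and ``my_torch_model`` is an arbitrary PyTorch module:

.. code-block:: python

   import nn_meter

   predictor = nn_meter.load_latency_predictor('cortexA76cpu_tflite21')
   latency_ms = predictor.predict(my_torch_model, model_type='torch',
                                  input_shape=(1, 3, 224, 224))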
Below we will describe implementation details. Like other one-shot NAS algorithms on NNI, ProxylessNAS is composed of two parts: *search space* and *training approach*. For users to flexibly define their own search space and use the built-in ProxylessNAS training approach, we put the specified search space in :githublink:`example code <examples/nas/oneshot/proxylessnas>` using :githublink:`NNI NAS interface <nni/retiarii/oneshot/pytorch/proxyless>`.
.. image:: ../../img/proxylessnas.png
   :target: ../../img/proxylessnas.png
   :alt:

The ProxylessNAS training approach is composed of ProxylessLayerChoice and ProxylessTrainer. ProxylessLayerChoice instantiates a MixedOp for each mutable (i.e., LayerChoice) and manages the architecture weights in the MixedOp. **For DataParallel**\ , architecture weights should be included in the user model. Specifically, in the ProxylessNAS implementation, we add the MixedOp to the corresponding mutable (i.e., LayerChoice) as a member variable. The ProxylessLayerChoice class also exposes two member functions, ``resample`` and ``finalize_grad``, for the trainer to control the training of architecture weights.

ProxylessLayerChoice also implements the forward logic of the mutables (i.e., LayerChoice), as sketched below.
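A simplified sketch of how the trainer drives these two functions during one architecture update step (condensed from the trainer code in this commit, not the exact source):

.. code-block:: python

   for _, module in nas_modules:
       module.resample()         # sample the active path for each LayerChoice
   logits, loss = logits_and_loss_for_arch_update(val_X, val_y)
   loss.backward()
   for _, module in nas_modules:
       module.finalize_grad()    # turn sampled-path gradients into gradients on alpha
   ctrl_optim.step()             # update the architecture parameters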
Reproduce Results
-----------------

To reproduce the result, we first run the search. We found that although it runs for many epochs, the chosen architecture converges within the first several epochs. This is probably caused by the hyper-parameters or the implementation; we are working on it.
@@ -26,6 +26,12 @@ if __name__ == "__main__":
parser.add_argument("--bn_eps", default=1e-3, type=float) parser.add_argument("--bn_eps", default=1e-3, type=float)
parser.add_argument("--dropout_rate", default=0, type=float) parser.add_argument("--dropout_rate", default=0, type=float)
parser.add_argument("--no_decay_keys", default='bn', type=str, choices=[None, 'bn', 'bn#bias']) parser.add_argument("--no_decay_keys", default='bn', type=str, choices=[None, 'bn', 'bn#bias'])
parser.add_argument('--grad_reg_loss_type', default='add#linear', type=str, choices=['add#linear', 'mul#log'])
parser.add_argument('--grad_reg_loss_lambda', default=1e-1, type=float) # grad_reg_loss_params
parser.add_argument('--grad_reg_loss_alpha', default=0.2, type=float) # grad_reg_loss_params
parser.add_argument('--grad_reg_loss_beta', default=0.3, type=float) # grad_reg_loss_params
parser.add_argument("--applied_hardware", default=None, type=str, help='the hardware to predict model latency')
parser.add_argument("--reference_latency", default=None, type=float, help='the reference latency in specified hardware')
# configurations of imagenet dataset # configurations of imagenet dataset
parser.add_argument("--data_path", default='/data/imagenet/', type=str) parser.add_argument("--data_path", default='/data/imagenet/', type=str)
parser.add_argument("--train_batch_size", default=256, type=int) parser.add_argument("--train_batch_size", default=256, type=int)
@@ -81,8 +87,19 @@ if __name__ == "__main__":
            {'params': get_parameters(model, keys, mode='include'), 'weight_decay': 0},
        ], lr=0.05, momentum=momentum, nesterov=nesterov)
    else:
        momentum, nesterov = 0.9, True
        optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5)

    if args.grad_reg_loss_type == 'add#linear':
        grad_reg_loss_params = {'lambda': args.grad_reg_loss_lambda}
    elif args.grad_reg_loss_type == 'mul#log':
        grad_reg_loss_params = {
            'alpha': args.grad_reg_loss_alpha,
            'beta': args.grad_reg_loss_beta,
        }
    else:
        grad_reg_loss_params = None  # assign the local variable passed to the trainer below, not an attribute on args
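    # Example (hypothetical invocation): running this script with
    #   --grad_reg_loss_type add#linear --grad_reg_loss_lambda 0.1
    # yields grad_reg_loss_params == {'lambda': 0.1}, which is forwarded to ProxylessTrainer below.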
    if args.train_mode == 'search':
        from nni.retiarii.oneshot.pytorch import ProxylessTrainer
        from torchvision.datasets import ImageNet
@@ -100,7 +117,11 @@ if __name__ == "__main__":
                                   optimizer=optimizer,
                                   metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
                                   num_epochs=120,
                                   log_frequency=10,
                                   grad_reg_loss_type=args.grad_reg_loss_type,
                                   grad_reg_loss_params=grad_reg_loss_params,
                                   applied_hardware=args.applied_hardware, dummy_input=(1, 3, 224, 224),
                                   ref_latency=args.reference_latency)
        trainer.fit()
        print('Final architecture:', trainer.export())
        json.dump(trainer.export(), open('checkpoint.json', 'w'))
...
import torch
import nni.retiarii.nn.pytorch as nn
import math

import ops
import putils
from nni.retiarii.nn.pytorch import LayerChoice


class SearchMobileNet(nn.Module):
    def __init__(self,
@@ -57,11 +57,9 @@ class SearchMobileNet(nn.Module):
                if stride == 1 and input_channel == width:
                    # if it is not the first one
                    op_candidates += [ops.OPS['Zero'](input_channel, width, stride)]
                    conv_op = LayerChoice(op_candidates, label="s{}_c{}".format(stage_cnt, i))
                else:
                    conv_op = LayerChoice(op_candidates, label="s{}_c{}".format(stage_cnt, i))
                # shortcut
                if stride == 1 and input_channel == width:
                    # if not first cell
...
from collections import OrderedDict

from nni.retiarii.serializer import basic_unit
import torch
import nni.retiarii.nn.pytorch as nn

from putils import get_same_padding, build_activation
@@ -35,14 +36,25 @@ class MobileInvertedResidualBlock(nn.Module):
        super(MobileInvertedResidualBlock, self).__init__()

        self.mobile_inverted_conv = mobile_inverted_conv
        self.op_candidates_list = op_candidates_list
        self.zero_layer_module = ZeroLayerModule(shortcut)

    def forward(self, x):
        out = self.mobile_inverted_conv(x)
        return self.zero_layer_module(x, out)


@basic_unit
class ZeroLayerModule(nn.Module):
    def __init__(self, shortcut):
        super().__init__()
        self.shortcut = shortcut

    def forward(self, x, out):
        # nested ifs instead of `and`: the graph converter does not support `if A and/or B`
        if torch.sum(torch.abs(out)).item() == 0:
            if x.size() == out.size():
                # is zero layer
                return x
        if self.shortcut is None:
            return out
        return out + self.shortcut(x)
@@ -108,6 +120,7 @@ class Base2DLayer(nn.Module):
                self.add_module(key, modules['weight'][key])
            else:
                self.add_module(op, modules[op])
        self.sequence = nn.Sequential(self._modules)

    @property
    def ops_list(self):
@@ -120,14 +133,13 @@ class Base2DLayer(nn.Module):
                return True
            elif op == 'weight':
                return False
        raise ValueError(f'Invalid ops_order: {self.ops_order}')

    def weight_op(self):
        raise NotImplementedError

    def forward(self, x):
        x = self.sequence(x)
        return x

    @staticmethod
@@ -224,6 +236,7 @@ class LinearLayer(nn.Module):
                self.add_module(key, modules['weight'][key])
            else:
                self.add_module(op, modules[op])
        self.sequence = nn.Sequential(self._modules)

    @property
    def ops_list(self):
@@ -236,11 +249,10 @@ class LinearLayer(nn.Module):
                return True
            elif op == 'weight':
                return False
        raise ValueError(f'Invalid ops_order: {self.ops_order}')

    def forward(self, x):
        x = self.sequence(x)
        return x

    @staticmethod
@@ -270,7 +282,7 @@ class MBInvertedConvLayer(nn.Module):
            feature_dim = self.mid_channels

        if self.expand_ratio == 1:
            self.inverted_bottleneck = nn.Sequential()
        else:
            self.inverted_bottleneck = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
@@ -291,8 +303,7 @@ class MBInvertedConvLayer(nn.Module):
            ]))

    def forward(self, x):
        x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_linear(x)
        return x
...
import torch
import nni.retiarii.nn.pytorch as nn


def get_parameters(model, keys=None, mode='include'):
...
@@ -5,7 +5,7 @@ import re
import torch

from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, Placeholder, LayerChoice
from ..operation import Cell, Operation
from ..serializer import get_init_parameters_or_fail
@@ -249,6 +249,15 @@ class GraphConverter:
            return f'({left} < {right})'
        elif tensor.node().kind() == 'prim::If':
            raise RuntimeError('Have not supported `if A and/or B`, please use two `if` statements instead.')
        elif tensor.node().kind() == 'aten::abs':
            value = _generate_expr(tensor.node().inputsAt(0))
            return f'(torch.abs({value}))'
        elif tensor.node().kind() == 'aten::sum':
            value = _generate_expr(tensor.node().inputsAt(0))
            return f'(torch.sum({value}))'
        elif tensor.node().kind() == 'aten::item':
            value = _generate_expr(tensor.node().inputsAt(0))
            return f'({value}.item())'
        else:
            raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition, '
                               'you are suggested to decorate the corresponding class with "@basic_unit".')
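        # Note: with aten::abs, aten::sum and aten::item handled above, a guard such as
        #     if torch.sum(torch.abs(out)).item() == 0:
        # (used by ZeroLayerModule in the ProxylessNAS example) can now be converted.
        # `and` in conditions is still unsupported, hence the nested ifs in that module.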
@@ -712,7 +721,7 @@ class GraphConverterWithShape(GraphConverter):
        # trace each layerchoice
        for name, submodule in module.named_modules():
            # TODO: support InputChoice and ValueChoice
            if isinstance(submodule, LayerChoice):
                full_name = get_full_name_by_scope_name(ir_model, name.split('.'), module_name)
                lc_node = ir_model.get_node_by_name(full_name)
@@ -768,66 +777,6 @@ class GraphConverterWithShape(GraphConverter):
        for node in ir_model.get_nodes():
            propagate_shape_for_graph(node.graph)
    def _trace(self, module, dummy_input):
        traced_module = torch.jit.trace(module, dummy_input)
        torch._C._jit_pass_inline(traced_module.graph)
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.

from ..operation import Cell
from ..graph import Model, Graph, Node, Edge


def build_full_name(prefix, name, seq=None):
@@ -110,3 +110,138 @@ def match_node(ir_model: Model, torch_node, prefix=''):
def _without_shape_info(node: Node):
    return not node.operation.attributes['input_shape'] and not node.operation.attributes['output_shape']
def flatten_model_graph(ir_model: Model):
    """
    Flatten the subgraphs into the root graph.
    """
    def _flatten(graph: Graph):
        """
        flatten this graph
        """
        model = graph.model
        node_to_remove = []

        for node in graph.hidden_nodes:
            node_graph = model.graphs.get(node.name)
            if node_graph is not None:
                _flatten(node_graph)

                # flatten node graph into this graph
                id_to_new_node = {}
                for node_graph_node in node_graph.hidden_nodes:
                    new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
                    new_node.update_label(node_graph_node.label)
                    new_node._register()
                    id_to_new_node[new_node.id] = new_node

                # reconnect node edges
                for in_edge in node.incoming_edges:
                    graph.del_edge(in_edge)
                    for input_node_edge in node_graph.input_node.outgoing_edges:
                        if input_node_edge.head_slot == in_edge.tail_slot:
                            graph.add_edge(
                                head=(in_edge.head, in_edge.head_slot),
                                tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))

                for out_edge in node.outgoing_edges:
                    graph.del_edge(out_edge)
                    for output_node_edge in node_graph.output_node.incoming_edges:
                        if output_node_edge.head_slot == out_edge.tail_slot:
                            graph.add_edge(
                                head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
                                tail=(out_edge.tail, out_edge.tail_slot))

                for edge in node_graph.edges:
                    if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
                        continue
                    new_head = id_to_new_node[edge.head.id]
                    new_tail = id_to_new_node[edge.tail.id]
                    Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

                node_to_remove.append(node)
                del model.graphs[node.name]

        for node in node_to_remove:
            node.remove()

    new_ir_model = ir_model.fork()
    _flatten(new_ir_model.root_graph)

    # remove subgraphs
    new_ir_model.graphs = {new_ir_model._root_graph_name: new_ir_model.root_graph}
    return new_ir_model
def flatten_model_graph_without_layerchoice(ir_model: Model):
    """
    Flatten the subgraphs into the root graph and skip all layerchoice nodes
    """
    def _flatten_without_layerchoice(graph: Graph):
        """
        flatten this graph
        """
        model = graph.model
        node_to_remove = []

        for node in graph.hidden_nodes:
            if is_layerchoice_node(node):
                for in_edge in node.incoming_edges:
                    graph.del_edge(in_edge)
                for out_edge in node.outgoing_edges:
                    graph.del_edge(out_edge)
                del model.graphs[node.name]
                node.remove()
                return

            node_graph = model.graphs.get(node.name)
            if node_graph is not None:
                _flatten_without_layerchoice(node_graph)

                # flatten node graph into this graph
                id_to_new_node = {}
                for node_graph_node in node_graph.hidden_nodes:
                    new_node = Node(graph, node_graph_node.id, node_graph_node.name, node_graph_node.operation, _internal=True)
                    new_node.update_label(node_graph_node.label)
                    new_node._register()
                    id_to_new_node[new_node.id] = new_node

                # reconnect node edges
                for in_edge in node.incoming_edges:
                    graph.del_edge(in_edge)
                    for input_node_edge in node_graph.input_node.outgoing_edges:
                        if input_node_edge.head_slot == in_edge.tail_slot:
                            graph.add_edge(
                                head=(in_edge.head, in_edge.head_slot),
                                tail=(id_to_new_node[input_node_edge.tail.id], input_node_edge.tail_slot))

                for out_edge in node.outgoing_edges:
                    graph.del_edge(out_edge)
                    for output_node_edge in node_graph.output_node.incoming_edges:
                        if output_node_edge.head_slot == out_edge.tail_slot:
                            graph.add_edge(
                                head=(id_to_new_node[output_node_edge.head.id], output_node_edge.head_slot),
                                tail=(out_edge.tail, out_edge.tail_slot))

                for edge in node_graph.edges:
                    if edge.head == node_graph.input_node or edge.tail == node_graph.output_node:
                        continue
                    new_head = id_to_new_node[edge.head.id]
                    new_tail = id_to_new_node[edge.tail.id]
                    Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

                node_to_remove.append(node)
                del model.graphs[node.name]

        for node in node_to_remove:
            node.remove()

    new_ir_model = ir_model.fork()
    _flatten_without_layerchoice(new_ir_model.root_graph)

    # remove subgraphs
    new_ir_model.graphs = {new_ir_model._root_graph_name: new_ir_model.root_graph}
    return new_ir_model
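# Hedged usage note: both helpers above fork the model before flattening, so the original
# ir_model is left untouched. For example:
#     flat = flatten_model_graph(ir_model)                             # single root graph, layer choices kept
#     flat_no_lc = flatten_model_graph_without_layerchoice(ir_model)   # layer choices removed as well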
@@ -180,7 +180,7 @@ class Model:
        There could be multiple nodes with the same label. Name space name can uniquely
        identify a graph or node.

        NOTE: the implementation does not support the class abstraction
        """
        matched_nodes = []
        for graph in self.graphs.values():
@@ -212,6 +212,13 @@ class Model:
        else:
            return None

    def get_cell_nodes(self) -> List['Node']:
        matched_nodes = []
        for graph in self.graphs.values():
            nodes = [node for node in graph.nodes if isinstance(node.operation, Cell)]
            matched_nodes.extend(nodes)
        return matched_nodes


class ModelStatus(Enum):
    """
...
@@ -95,12 +95,84 @@ class ProxylessLayerChoice(nn.Module):
    def export(self):
        return torch.argmax(self.alpha).item()

    def export_prob(self):
        return F.softmax(self.alpha, dim=-1)


class ProxylessInputChoice(nn.Module):
    def __init__(self, *args, **kwargs):
        raise NotImplementedError('Input choice is not supported for ProxylessNAS.')
class HardwareLatencyEstimator:
    def __init__(self, applied_hardware, model, dummy_input=(1, 3, 224, 224), dump_lat_table='data/latency_table.yaml'):
        import nn_meter  # pylint: disable=import-error
        _logger.info(f'Load latency predictor for applied hardware: {applied_hardware}.')
        self.predictor_name = applied_hardware
        self.latency_predictor = nn_meter.load_latency_predictor(applied_hardware)
        self.block_latency_table = self._form_latency_table(model, dummy_input, dump_lat_table=dump_lat_table)

    def _form_latency_table(self, model, dummy_input, dump_lat_table):
        latency_table = {}

        from nni.retiarii.converter import convert_to_graph
        from nni.retiarii.converter.graph_gen import GraphConverterWithShape
        from nni.retiarii.converter.utils import flatten_model_graph_without_layerchoice, is_layerchoice_node
        script_module = torch.jit.script(model)
        base_model_ir = convert_to_graph(script_module, model,
                                         converter=GraphConverterWithShape(), dummy_input=torch.randn(*dummy_input))

        # form the latency of layerchoice blocks for the latency table
        temp_ir_model = base_model_ir.fork()
        cell_nodes = base_model_ir.get_cell_nodes()
        layerchoice_nodes = [node for node in cell_nodes if is_layerchoice_node(node)]
        for lc_node in layerchoice_nodes:
            cand_lat = {}
            for candidate in lc_node.operation.parameters['candidates']:
                node_graph = base_model_ir.graphs.get(candidate)
                if node_graph is not None:
                    temp_ir_model._root_graph_name = node_graph.name
                    latency = self.latency_predictor.predict(temp_ir_model, model_type='nni-ir')
                else:
                    _logger.warning(f"Could not find graph for layerchoice candidate {candidate}")
                    latency = 0
                cand_lat[candidate.split('_')[-1]] = float(latency)
            latency_table[lc_node.operation.parameters['label']] = cand_lat

        # form the latency of the stationary block in the latency table
        temp_ir_model._root_graph_name = base_model_ir._root_graph_name
        temp_ir_model = flatten_model_graph_without_layerchoice(temp_ir_model)
        latency = self.latency_predictor.predict(temp_ir_model, model_type='nni-ir')
        latency_table['stationary_block'] = {'root': float(latency)}

        # save latency table
        if dump_lat_table:
            import os
            import yaml
            os.makedirs(os.path.dirname(dump_lat_table), exist_ok=True)
            with open(dump_lat_table, 'a') as fp:
                yaml.dump([{
                    'applied_hardware': self.predictor_name,
                    'latency_table': latency_table
                }], fp)
        _logger.info("Latency lookup table built")
        return latency_table

    def cal_expected_latency(self, current_architecture_prob):
        lat = self.block_latency_table['stationary_block']['root']
        for module_name, probs in current_architecture_prob.items():
            assert len(probs) == len(self.block_latency_table[module_name])
            # plain Python sum keeps the expected latency differentiable w.r.t. the
            # architecture probabilities (wrapping the terms in torch.tensor would detach them)
            lat = lat + sum(probs[i] * self.block_latency_table[module_name][str(i)]
                            for i in range(len(probs)))
        return lat

    def export_latency(self, current_architecture):
        lat = self.block_latency_table['stationary_block']['root']
        for module_name, selected_module in current_architecture.items():
            lat += self.block_latency_table[module_name][str(selected_module)]
        return lat
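# A minimal usage sketch (assumed variable names, mirroring how ProxylessTrainer below
# consumes the estimator; not an additional API):
#
#     estimator = HardwareLatencyEstimator('cortexA76cpu_tflite21', model)
#     probs = {name: module.export_prob() for name, module in nas_modules}
#     expected = estimator.cal_expected_latency(probs)   # differentiable, used in the reg loss
#     chosen = {name: module.export() for name, module in nas_modules}
#     actual = estimator.export_latency(chosen)          # plain float, logged as 'latency'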
class ProxylessTrainer(BaseOneShotTrainer):
    """
    Proxyless trainer.
@@ -131,12 +203,31 @@ class ProxylessTrainer(BaseOneShotTrainer):
        Step count per logging.
    arc_learning_rate : float
        Learning rate of architecture parameters.
    grad_reg_loss_type : string
        Regularization type for the hardware-related loss term; allowed types are

        - ``"mul#log"``: ``regularized_loss = (torch.log(expected_latency) / math.log(self.ref_latency)) ** beta``
        - ``"add#linear"``: ``regularized_loss = reg_lambda * (expected_latency - self.ref_latency) / self.ref_latency``
        - None: do not apply loss regularization.
    grad_reg_loss_params : dict
        Regularization parameters; allowed keys are

        - ``"alpha"`` and ``"beta"``, required when ``grad_reg_loss_type == "mul#log"``
        - ``"lambda"``, required when ``grad_reg_loss_type == "add#linear"``
    applied_hardware : string
        Hardware used to constrain the model's latency. Latency is predicted by Microsoft
        nn-Meter (https://github.com/microsoft/nn-Meter).
    dummy_input : tuple
        The dummy input shape used when profiling the model on the target hardware.
    ref_latency : float
        Reference latency value on the applied hardware (ms).
    """
    def __init__(self, model, loss, metrics, optimizer,
                 num_epochs, dataset, warmup_epochs=0,
                 batch_size=64, workers=4, device=None, log_frequency=None,
                 arc_learning_rate=1.0E-3,
                 grad_reg_loss_type=None, grad_reg_loss_params=None,
                 applied_hardware=None, dummy_input=(1, 3, 224, 224),
                 ref_latency=65.0):
        self.model = model
        self.loss = loss
        self.metrics = metrics
@@ -148,8 +239,17 @@ class ProxylessTrainer(BaseOneShotTrainer):
        self.workers = workers
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        self.log_frequency = log_frequency

        # latency predictor
        if applied_hardware:
            self.latency_estimator = HardwareLatencyEstimator(applied_hardware, self.model, dummy_input)
        else:
            self.latency_estimator = None
        self.reg_loss_type = grad_reg_loss_type
        self.reg_loss_params = {} if grad_reg_loss_params is None else grad_reg_loss_params
        self.ref_latency = ref_latency

        self.model.to(self.device)

        self.nas_modules = []
        replace_layer_choice(self.model, ProxylessLayerChoice, self.nas_modules)
        replace_input_choice(self.model, ProxylessInputChoice, self.nas_modules)
@@ -189,7 +289,7 @@ class ProxylessTrainer(BaseOneShotTrainer):
            for _, module in self.nas_modules:
                module.resample()
            self.ctrl_optim.zero_grad()
            logits, loss = self._logits_and_loss_for_arch_update(val_X, val_y)
            loss.backward()
            for _, module in self.nas_modules:
                module.finalize_grad()
@@ -199,21 +299,60 @@ class ProxylessTrainer(BaseOneShotTrainer):
            for _, module in self.nas_modules:
                module.resample()
            self.optimizer.zero_grad()
            logits, loss = self._logits_and_loss_for_weight_update(trn_X, trn_y)
            loss.backward()
            self.optimizer.step()

            metrics = self.metrics(logits, trn_y)
            metrics["loss"] = loss.item()
            if self.latency_estimator:
                metrics["latency"] = self._export_latency()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                _logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                             self.num_epochs, step + 1, len(self.train_loader), meters)
    def _logits_and_loss_for_arch_update(self, X, y):
        ''' return logits and loss for architecture parameter update '''
        logits = self.model(X)
        ce_loss = self.loss(logits, y)
        if not self.latency_estimator:
            return logits, ce_loss

        current_architecture_prob = {}
        for module_name, module in self.nas_modules:
            probs = module.export_prob()
            current_architecture_prob[module_name] = probs
        expected_latency = self.latency_estimator.cal_expected_latency(current_architecture_prob)

        if self.reg_loss_type == 'mul#log':
            import math
            alpha = self.reg_loss_params.get('alpha', 1)
            beta = self.reg_loss_params.get('beta', 0.6)
            # noinspection PyUnresolvedReferences
            reg_loss = (torch.log(expected_latency) / math.log(self.ref_latency)) ** beta
            return logits, alpha * ce_loss * reg_loss
        elif self.reg_loss_type == 'add#linear':
            reg_lambda = self.reg_loss_params.get('lambda', 2e-1)
            reg_loss = reg_lambda * (expected_latency - self.ref_latency) / self.ref_latency
            return logits, ce_loss + reg_loss
        elif self.reg_loss_type is None:
            return logits, ce_loss
        else:
            raise ValueError(f'Do not support: {self.reg_loss_type}')
    def _logits_and_loss_for_weight_update(self, X, y):
        ''' return logits and loss for weight parameter update '''
        logits = self.model(X)
        loss = self.loss(logits, y)
        return logits, loss

    def _export_latency(self):
        current_architecture = {}
        for module_name, module in self.nas_modules:
            selected_module = module.export()
            current_architecture[module_name] = selected_module
        return self.latency_estimator.export_latency(current_architecture)

    def fit(self):
        for i in range(self.num_epochs):
            self._train_one_epoch(i)
...
@@ -121,7 +121,7 @@ class SinglePathTrainer(BaseOneShotTrainer):
    def __init__(self, model, loss, metrics,
                 optimizer, num_epochs, dataset_train, dataset_valid,
                 batch_size=64, workers=4, device=None, log_frequency=None):
        self.model = model
        self.loss = loss
        self.metrics = metrics
...
@@ -1387,5 +1387,13 @@ class TestOperators(unittest.TestCase, ConvertMixin):
        x = torch.randn(20, 5, 10, 10)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_abs(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.abs(x)
                return out

        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        self.checkExportImport(SimpleOp(), (x, ))


class TestOperatorsWithShape(TestOperators, ConvertWithShapeMixin):
    pass