Unverified Commit 468917ca authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

Merge pull request #3155 from microsoft/dev-retiarii

[Do NOT Squash] Merge retiarii dev branch to master
parents f8424a9f d5a551c8
...@@ -98,3 +98,6 @@ venv.bak/ ...@@ -98,3 +98,6 @@ venv.bak/
# VSCode # VSCode
.vscode .vscode
.vs .vs
.history
generated/
test/ut/retiarii/_debug_graph_data.json
...@@ -11,9 +11,9 @@ import torch.nn as nn ...@@ -11,9 +11,9 @@ import torch.nn as nn
import datasets import datasets
from model import CNN from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from utils import accuracy from utils import accuracy
logger = logging.getLogger('nni') logger = logging.getLogger('nni')
if __name__ == "__main__": if __name__ == "__main__":
...@@ -25,6 +25,7 @@ if __name__ == "__main__": ...@@ -25,6 +25,7 @@ if __name__ == "__main__":
parser.add_argument("--channels", default=16, type=int) parser.add_argument("--channels", default=16, type=int)
parser.add_argument("--unrolled", default=False, action="store_true") parser.add_argument("--unrolled", default=False, action="store_true")
parser.add_argument("--visualization", default=False, action="store_true") parser.add_argument("--visualization", default=False, action="store_true")
parser.add_argument("--v1", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10") dataset_train, dataset_valid = datasets.get_dataset("cifar10")
...@@ -35,6 +36,8 @@ if __name__ == "__main__": ...@@ -35,6 +36,8 @@ if __name__ == "__main__":
optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)
if args.v1:
from nni.algorithms.nas.pytorch.darts import DartsTrainer
trainer = DartsTrainer(model, trainer = DartsTrainer(model,
loss=criterion, loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)), metrics=lambda output, target: accuracy(output, target, topk=(1,)),
...@@ -48,4 +51,20 @@ if __name__ == "__main__": ...@@ -48,4 +51,20 @@ if __name__ == "__main__":
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")]) callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
if args.visualization: if args.visualization:
trainer.enable_visualization() trainer.enable_visualization()
trainer.train() trainer.train()
else:
from nni.retiarii.trainer.pytorch import DartsTrainer
trainer = DartsTrainer(
model=model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
optimizer=optim,
num_epochs=args.epochs,
dataset=dataset_train,
batch_size=args.batch_size,
log_frequency=args.log_frequency,
unrolled=args.unrolled
)
trainer.fit()
print('Final architecture:', trainer.export())
...@@ -48,7 +48,13 @@ class Cell(nn.Module): ...@@ -48,7 +48,13 @@ class Cell(nn.Module):
], key=cell_name + "_op") ], key=cell_name + "_op")
def forward(self, prev_layers): def forward(self, prev_layers):
chosen_input, chosen_mask = self.input_choice(prev_layers) from nni.retiarii.trainer.pytorch.random import PathSamplingInputChoice
out = self.input_choice(prev_layers)
if isinstance(self.input_choice, PathSamplingInputChoice):
# Retiarii pattern
return out, self.input_choice.mask
else:
chosen_input, chosen_mask = out
cell_out = self.op_choice(chosen_input) cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask return cell_out, chosen_mask
......
...@@ -26,17 +26,22 @@ if __name__ == "__main__": ...@@ -26,17 +26,22 @@ if __name__ == "__main__":
parser.add_argument("--search-for", choices=["macro", "micro"], default="macro") parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)") parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
parser.add_argument("--visualization", default=False, action="store_true") parser.add_argument("--visualization", default=False, action="store_true")
parser.add_argument("--v1", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10") dataset_train, dataset_valid = datasets.get_dataset("cifar10")
mutator = None
ctrl_kwargs = {}
if args.search_for == "macro": if args.search_for == "macro":
model = GeneralNetwork() model = GeneralNetwork()
num_epochs = args.epochs or 310 num_epochs = args.epochs or 310
mutator = None
elif args.search_for == "micro": elif args.search_for == "micro":
model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=False)
num_epochs = args.epochs or 150 num_epochs = args.epochs or 150
if args.v1:
mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True) mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
else:
ctrl_kwargs = {"tanh_constant": 1.1}
else: else:
raise AssertionError raise AssertionError
...@@ -44,6 +49,7 @@ if __name__ == "__main__": ...@@ -44,6 +49,7 @@ if __name__ == "__main__":
optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)
if args.v1:
trainer = enas.EnasTrainer(model, trainer = enas.EnasTrainer(model,
loss=criterion, loss=criterion,
metrics=accuracy, metrics=accuracy,
...@@ -59,3 +65,16 @@ if __name__ == "__main__": ...@@ -59,3 +65,16 @@ if __name__ == "__main__":
if args.visualization: if args.visualization:
trainer.enable_visualization() trainer.enable_visualization()
trainer.train() trainer.train()
else:
from nni.retiarii.trainer.pytorch.enas import EnasTrainer
trainer = EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optimizer,
batch_size=args.batch_size,
num_epochs=num_epochs,
dataset=dataset_train,
log_frequency=args.log_frequency,
ctrl_kwargs=ctrl_kwargs)
trainer.fit()
...@@ -6,7 +6,7 @@ import torchvision ...@@ -6,7 +6,7 @@ import torchvision
import torchvision.transforms as transforms import torchvision.transforms as transforms
from nni.nas.pytorch.mutables import LayerChoice, InputChoice from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.darts import DartsTrainer from nni.algorithms.nas.pytorch.darts import DartsTrainer
class Net(nn.Module): class Net(nn.Module):
......
import logging
import os import os
import sys import sys
import logging
from argparse import ArgumentParser from argparse import ArgumentParser
import torch import torch
import datasets from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from torchvision import transforms
from putils import get_parameters import datasets
from model import SearchMobileNet from model import SearchMobileNet
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from putils import LabelSmoothingLoss, accuracy, get_parameters
from retrain import Retrain from retrain import Retrain
logger = logging.getLogger('nni_proxylessnas') logger = logging.getLogger('nni_proxylessnas')
...@@ -30,7 +33,7 @@ if __name__ == "__main__": ...@@ -30,7 +33,7 @@ if __name__ == "__main__":
parser.add_argument("--resize_scale", default=0.08, type=float) parser.add_argument("--resize_scale", default=0.08, type=float)
parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None']) parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
# configurations for training mode # configurations for training mode
parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain']) parser.add_argument("--train_mode", default='search', type=str, choices=['search_v1', 'search', 'retrain'])
# configurations for search # configurations for search
parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str) parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
parser.add_argument("--arch_path", default='./arch_path.pt', type=str) parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
...@@ -80,6 +83,26 @@ if __name__ == "__main__": ...@@ -80,6 +83,26 @@ if __name__ == "__main__":
optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5) optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5)
if args.train_mode == 'search': if args.train_mode == 'search':
from nni.retiarii.trainer.pytorch import ProxylessTrainer
from torchvision.datasets import ImageNet
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = ImageNet(args.data_path, transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
trainer = ProxylessTrainer(model,
loss=LabelSmoothingLoss(),
dataset=dataset,
optimizer=optimizer,
metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
num_epochs=120,
log_frequency=10)
trainer.fit()
print('Final architecture:', trainer.export())
elif args.train_mode == 'search_v1':
# this is architecture search # this is architecture search
logger.info('Creating ProxylessNasTrainer...') logger.info('Creating ProxylessNasTrainer...')
trainer = ProxylessNasTrainer(model, trainer = ProxylessNasTrainer(model,
......
...@@ -58,11 +58,9 @@ class SearchMobileNet(nn.Module): ...@@ -58,11 +58,9 @@ class SearchMobileNet(nn.Module):
# if it is not the first one # if it is not the first one
op_candidates += [ops.OPS['Zero'](input_channel, width, stride)] op_candidates += [ops.OPS['Zero'](input_channel, width, stride)]
conv_op = nas.mutables.LayerChoice(op_candidates, conv_op = nas.mutables.LayerChoice(op_candidates,
return_mask=True,
key="s{}_c{}".format(stage_cnt, i)) key="s{}_c{}".format(stage_cnt, i))
else: else:
conv_op = nas.mutables.LayerChoice(op_candidates, conv_op = nas.mutables.LayerChoice(op_candidates,
return_mask=True,
key="s{}_c{}".format(stage_cnt, i)) key="s{}_c{}".format(stage_cnt, i))
# shortcut # shortcut
if stride == 1 and input_channel == width: if stride == 1 and input_channel == width:
......
...@@ -39,19 +39,13 @@ class MobileInvertedResidualBlock(nn.Module): ...@@ -39,19 +39,13 @@ class MobileInvertedResidualBlock(nn.Module):
self.op_candidates_list = op_candidates_list self.op_candidates_list = op_candidates_list
def forward(self, x): def forward(self, x):
out, idx = self.mobile_inverted_conv(x) out = self.mobile_inverted_conv(x)
# TODO: unify idx format if torch.sum(torch.abs(out)).item() == 0 and x.size() == out.size():
if not isinstance(idx, int): # is zero layer
idx = (idx == 1).nonzero() return x
if self.op_candidates_list[idx].is_zero_layer(): if self.shortcut is None:
res = x return out
elif self.shortcut is None: return out + self.shortcut(x)
res = out
else:
conv_x = out
skip_x = self.shortcut(x)
res = skip_x + conv_x
return res
class ShuffleLayer(nn.Module): class ShuffleLayer(nn.Module):
......
import torch
import torch.nn as nn import torch.nn as nn
def get_parameters(model, keys=None, mode='include'): def get_parameters(model, keys=None, mode='include'):
if keys is None: if keys is None:
for name, param in model.named_parameters(): for name, param in model.named_parameters():
...@@ -36,6 +38,7 @@ def get_same_padding(kernel_size): ...@@ -36,6 +38,7 @@ def get_same_padding(kernel_size):
assert kernel_size % 2 > 0, 'kernel size should be odd number' assert kernel_size % 2 > 0, 'kernel size should be odd number'
return kernel_size // 2 return kernel_size // 2
def build_activation(act_func, inplace=True): def build_activation(act_func, inplace=True):
if act_func == 'relu': if act_func == 'relu':
return nn.ReLU(inplace=inplace) return nn.ReLU(inplace=inplace)
...@@ -65,3 +68,40 @@ def make_divisible(v, divisor, min_val=None): ...@@ -65,3 +68,40 @@ def make_divisible(v, divisor, min_val=None):
if new_v < 0.9 * v: if new_v < 0.9 * v:
new_v += divisor new_v += divisor
return new_v return new_v
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
class LabelSmoothingLoss(nn.Module):
def __init__(self, smoothing=0.1, dim=-1):
super(LabelSmoothingLoss, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.dim = dim
def forward(self, pred, target):
pred = pred.log_softmax(dim=self.dim)
num_classes = pred.size(self.dim)
with torch.no_grad():
true_dist = torch.zeros_like(pred)
true_dist.fill_(self.smoothing / (num_classes - 1))
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
...@@ -109,7 +109,13 @@ class MutableScope(Mutable): ...@@ -109,7 +109,13 @@ class MutableScope(Mutable):
def __init__(self, key): def __init__(self, key):
super().__init__(key=key) super().__init__(key=key)
def _check_built(self):
return True # bypass the test because it's deprecated
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
if not hasattr(self, 'mutator'):
return super().__call__(*args, **kwargs)
warnings.warn("`MutableScope` is deprecated in Retiarii.", DeprecationWarning)
try: try:
self._check_built() self._check_built()
self.mutator.enter_mutable_scope(self) self.mutator.enter_mutable_scope(self)
......
from .operation import Operation
from .graph import *
from .execution import *
from .mutator import *
from .utils import register_module
\ No newline at end of file
from .pytorch import model_to_pytorch_script
import logging
from typing import List
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
_logger = logging.getLogger(__name__)
def model_to_pytorch_script(model: Model, placement=None) -> str:
graphs = []
total_pkgs = set()
for name, cell in model.graphs.items():
import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
graphs.append(graph_code)
total_pkgs.update(import_pkgs)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
def _sorted_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
_logger.info('sorted_incoming_edges: %s', str(edges))
if not edges:
return []
_logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
if all(edge.tail_slot is None for edge in edges):
return edges
if all(isinstance(edge.tail_slot, int) for edge in edges):
edges = sorted(edges, key=(lambda edge: edge.tail_slot))
if [edge.tail_slot for edge in edges] == list(range(len(edges))):
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> List[str]:
edges = _sorted_incoming_edges(node)
inputs = []
for edge in edges:
if edge.head.name == '_inputs':
assert isinstance(edge.head_slot, int)
if edge.head.operation.io_names is not None:
# when input has names, e.g., forward(self, tensor1, tensor2, another_one)
inputs.append(edge.head.operation.io_names[edge.head_slot])
else:
# when input has no name, e.g., forward(*_inputs)
inputs.append('_inputs[{}]'.format(edge.head_slot))
else:
if edge.head_slot is None:
# when the input comes from a single-output operator
inputs.append('{}'.format(edge.head.name))
else:
# when the input comes from a multi-output operator: needs to know which one it comes from
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
return inputs
def _remove_prefix(names, graph_name):
"""
variables name (full name space) is too long,
shorten the name by removing the prefix ```graph_name```
"""
if isinstance(names, list):
converted_names = []
for name in names:
if name.startswith(graph_name):
converted_names.append(name[len(graph_name):])
else:
converted_names.append(name)
return converted_names
else:
return names[len(graph_name):] if names.startswith(graph_name) else names
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
nodes = graph.topo_sort()
# handle module node and function node differently
# only need to generate code for module here
import_pkgs = set()
node_codes = []
for node in nodes:
if node.operation:
pkg_name = node.operation.get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
node_codes.append(f"{node_code}.to('{placement[node].device}')")
else:
node_codes.append(node_code)
if graph.input_node.operation.io_names is None:
input_code = '*_inputs'
else:
for name in graph.input_node.operation.io_names:
assert not name.startswith(graph_name)
input_code = ', '.join(graph.input_node.operation.io_names)
edge_codes = []
sorted_nodes = graph.topo_sort()
for node in sorted_nodes:
if node.operation:
inputs = _format_inputs(node)
inputs = _remove_prefix(inputs, graph_name)
node_name = _remove_prefix(node.name, graph_name)
edge_codes.append(node.operation.to_forward_code(node_name, node_name, inputs))
output_names = _format_inputs(graph.output_node)
output_names = _remove_prefix(output_names, graph_name)
if not output_names:
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
output_code = ', '.join(output_names)
linebreak = '\n '
return import_pkgs, _PyTorchModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else graph_name),
inputs=input_code,
outputs=output_code,
nodes=linebreak.join(node_codes),
edges=linebreak.join(edge_codes)
)
# TODO: handle imports
_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
{}
{}
'''
_PyTorchModelTemplate = '''
class {graph_name}(nn.Module):
def __init__(self):
super().__init__()
{nodes}
def forward(self, {inputs}):
{edges}
return {outputs}
'''
# pylint: skip-file
"""
FIXME
This file is inherited from last version.
I expect it can work with a few modifications to incorporate with the latest API, but it hasn't
been tested and I'm not sure.
"""
from ..graph_v2 import IllegalGraphError, Cell, Edge, Graph, Node
from ..operations_tf import Operation
from ..type_utils import *
def graph_to_tensorflow_script(graph: Graph) -> str:
graphs = [graph_to_tensorflow_model(name, cell) for name, cell in graph.cell_templates.items()]
return _TensorFlowScriptTemplate.format('\n\n'.join(graphs)).strip()
def _sort_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
if not edges:
return []
if all(edge.tail_idx is None for edge in edges):
return edges
if all(isinstance(edge.tail_idx, int) for edge in edges):
edges = sorted(edges, key=(lambda edge: edge.tail_idx))
if [edge.tail_idx for edge in edges] == list(range(len(edges))):
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> str:
edges = _sort_incoming_edges(node)
inputs = []
for edge in edges:
if edge.head.name == '_inputs':
assert isinstance(edge.head_idx, int)
if node.graph.input_names is not None:
inputs.append(node.graph.input_names[edge.head_idx])
else:
inputs.append('_inputs[{}]'.format(edge.head_idx))
else:
if edge.head_idx is None:
inputs.append('{}'.format(edge.head.name))
else:
inputs.append('{}[{}]'.format(edge.head.name, edge.head_idx))
return ', '.join(inputs)
def graph_to_tensorflow_model(graph_name: str, graph: Graph) -> str:
nodes = graph.topo_sort()
# handle module node and function node differently
# only need to generate code for module here
node_codes = []
for node in nodes:
if isinstance(node, Cell):
node_codes.append('self.{} = {}()'.format(node.name, node.template_name))
else:
node_codes.append('self.{} = {}'.format(node.name, cast(Operation, node.operation).to_tensorflow_init()))
edge_codes = []
for node in nodes:
inputs = _format_inputs(node)
edge_codes.append('{} = self.{}({})'.format(node.name, node.name, inputs))
output_code = _format_inputs(graph.output_node)
if not output_code:
output_code = 'None'
if graph.input_names is None:
input_code = '*_inputs'
else:
input_code = ', '.join(graph.input_names)
linebreak = '\n '
return _TensorFlowModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else graph_name),
inputs=input_code,
outputs=output_code,
nodes=linebreak.join(node_codes),
edges=linebreak.join(edge_codes)
)
_TensorFlowScriptTemplate = '''
import tensorflow as tf
import tensorflow.keras as K
import sdk.custom_ops_tf as CUSTOM
{}
'''
_TensorFlowModelTemplate = '''
class {graph_name}(K.Model):
def __init__(self):
super().__init__()
{nodes}
def call(self, {inputs}):
{edges}
return {outputs}
'''
\ No newline at end of file
# PyTorch Graph Converter
## Namespace for PyTorch Graph
We should have a concrete rule for specifying nodes in graph with namespace.
Each node has a name, either specified or generated. The nodes in the same hierarchy cannot have the same name.
* The name of module node natively follows this rule, because we use variable name for instantiated modules like what PyTorch graph does.
* For the nodes created in `forward` function, we use a global sequence number.
### Namespace for mutated (new) nodes
TBD
## Graph Simplification
TBD
## Node Types
We define concrete type string for each node type.
## Module's Input Arguments
We use wrapper to obtain the input arguments of modules. Users need to use our wrapped "nn" and wrapped "Module".
## Control Flow
### for loop
Currently, we only support `ModuleList` (`ModuleDict`) based for loop, which is automatically unfolded by TorchScript. That is to say, we do not support loop in TorchScript for now.
### if/else
For now, we only deal with the case that the condition is constant or attribute. In this case, only one branch is kept during generating the graph.
\ No newline at end of file
from .graph_gen import convert_to_graph
import logging
import re
import torch
from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell
from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
from .utils import _convert_name, build_full_name
_logger = logging.getLogger(__name__)
global_seq = 0
global_graph_id = 0
modules_arg = None
def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
"""
Parameters
----------
ir_graph : Graph
node : torch._C.Node
graph_inputs : List[torch._C.Value]
a list of a script graph's inputs
node_index : Dict
new_node : Node
newly created ir node corresponding to `node`
output_remap : Dict
ignore_first : bool
if it is true, skip the first input
"""
is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
new_node_input_idx = 0
for _input in node.inputs():
if ignore_first:
ignore_first = False
continue
# handle source node
if _input in graph_inputs:
idx = graph_inputs.index(_input)
src_node = ir_graph.input_node
src_node_idx = idx
elif _input in output_remap:
assert output_remap[_input].kind() == 'aten::append'
predecessor_node = output_remap[_input]
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
src_node_idx = None
src_node = node_index[predecessor_node]
assert isinstance(src_node, Node)
else:
predecessor_node = _input.node()
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
# find out the index of _input in the outputs of predecessor_node
predecessor_outputs = [_output for _output in predecessor_node.outputs()]
if len(predecessor_outputs) == 1:
idx = None
else:
idx = predecessor_outputs.index(_input)
ir_predecessor_node = node_index[predecessor_node]
src_node_idx = idx
assert isinstance(ir_predecessor_node, Node)
src_node = ir_predecessor_node
# handle destination node
dst_node = new_node
if is_single_input:
dst_node_idx = None
else:
dst_node_idx = new_node_input_idx
# create edge
ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))
new_node_input_idx += 1
def create_prim_constant_node(ir_graph, node, module_name):
global global_seq
attrs = {}
if node.outputsAt(0).toIValue() is not None:
attrs = {'value': node.outputsAt(0).toIValue()}
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, global_seq),
node.kind(), attrs)
return new_node
def handle_prim_attr_node(node):
assert node.hasAttribute('name')
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
return node.kind(), attrs
def _remove_mangle(module_type_str):
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
def remove_unconnected_nodes(ir_graph, targeted_type=None):
"""
Parameters
----------
ir_graph : Graph
our ir graph representation
targeted_type : str
nodes with ```targeted_type``` will be removed from graph if their fanout is 0.
```None``` means removing all the nodes whose fanout is 0.
"""
# build index of outputs of Node(s)
node_fanout = set()
for edge in ir_graph.edges:
if edge.head.id not in node_fanout:
node_fanout.add(edge.head.id)
to_removes = []
for hidden_node in ir_graph.hidden_nodes:
if hidden_node.id not in node_fanout:
assert isinstance(hidden_node, Node)
if targeted_type is None:
to_removes.append(hidden_node)
elif hidden_node.operation.type == targeted_type:
to_removes.append(hidden_node)
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
"""
Convert torch script node to our node ir, and build our graph ir
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the torch script of ```module```
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
```module```'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ```module```
Returns
-------
dict
the mapping from graph node to our graph ir node
"""
# handle inputs
graph_inputs = []
for _input in sm_graph.inputs():
if _input.debugName() == 'self':
assert _input.unique() == 0
continue
graph_inputs.append(_input)
# TODO: add scope name
ir_graph._add_input(_convert_name(_input.debugName()))
node_index = {} # graph node to graph ir node
# some node does not have output but it modifies a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
# %out.1 is updated, and %17 is None
# we add output to this type of node and connect it to the following node which uses %out.1
# key: tensor (%out.1), value: node (this node)
output_remap = {}
def handle_if_condition(cond_tensor):
"""
to calculate the condition, we only deal with the following op types by tracing back
`prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`
generate the expression using recursive calls
NOTE: do not support dynamic graph
"""
def _generate_expr(tensor):
if tensor.node().kind() == 'prim::GetAttr':
return f'({getattr(module, tensor.node().s("name"))})'
elif tensor.node().kind() == 'aten::__getitem__':
t = _generate_expr(tensor.node().inputsAt(0))
idx = _generate_expr(tensor.node().inputsAt(1))
return f'({t}[{idx}])'
elif tensor.node().kind() == 'prim::Constant':
return f'{tensor.toIValue()}'
elif tensor.node().kind() == 'aten::eq':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})'
else:
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition')
expr = _generate_expr(cond_tensor)
return eval(expr)
def handle_if_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
# only deal with input of prim::If is constant or attribute for now
# will support constant expression in future
inputs = [i for i in node.inputs()]
assert len(inputs) == 1
cond = handle_if_condition(inputs[0])
chosen_block = 0 if cond else 1
blocks = [block for block in node.blocks()]
assert len(blocks) == 2
last_block_node = None
for node in blocks[chosen_block].nodes():
last_block_node = handle_single_node(node)
return last_block_node
def handle_single_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
global global_seq
if node.kind() == 'prim::CallMethod':
# get and handle the first input, which should be an nn.Module
assert node.hasAttribute('name')
if node.s('name') == 'forward':
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str = _remove_mangle(node.inputsAt(0).type().str())
submodule = node.inputsAt(0).node()
assert submodule.kind() == 'prim::GetAttr'
assert submodule.hasAttribute('name')
submodule_name = submodule.s('name')
if submodule.inputsAt(0).debugName() == 'self':
# module is usually instantiated in __init__.
# when calling a module in forward,
# prim::GetAttr is used to obtain the module in torch script.
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
submodule_name, script_module._modules.keys())
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, ir_model)
else:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
if submodule.inputsAt(0).type().name() == 'ModuleList':
# handle ModuleList
predecessor = submodule.inputsAt(0).node()
assert predecessor.kind() == 'prim::GetAttr'
assert predecessor.hasAttribute('name')
assert predecessor.inputsAt(0).debugName() == 'self'
predecessor_name = predecessor.s('name')
# FIXME: exchange
submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
predecessor_obj = getattr(module, predecessor_name)
submodule_obj = getattr(predecessor_obj, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
submodule_obj, submodule_full_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
# TODO: match subgraph with maintained graphs
# build cell
if subgraph is None:
# if we do not parse this module's graph, we create Node for this module
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
elif isinstance(submodule_obj, (LayerChoice, InputChoice)):
subcell.update_label(sub_m_attrs['label'])
else:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
node_index[node] = subcell
# connect the cell into graph
_add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
else:
raise RuntimeError('unsupported CallMethod {}'.format(node.s('name')))
elif node.kind() == 'prim::CallFunction':
func_type_str = _remove_mangle(node.inputsAt(0).type().str())
func = node.inputsAt(0).node()
assert func.kind() == 'prim::Constant'
assert func.hasAttribute('name')
func_name = func.s('name')
# create node for func
global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, global_seq),
'{}.{}'.format(func_type_str, func_name))
node_index[node] = func_node
_add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant':
new_node = create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct':
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, global_seq), node.kind())
node_index[node] = new_node
_add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append':
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = handle_prim_attr_node(node)
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# last_block_node is None means no node in the branch block
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise RuntimeError('Loop has not been supported yet!')
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
return node_index[node]
for node in sm_graph.nodes():
handle_single_node(node)
return node_index
def merge_aten_slices(ir_graph):
"""
if there is aten::slice node, merge the consecutive ones together.
```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
"""
head_slice_nodes = []
has_slice_node = False
for node in ir_graph.hidden_nodes:
if node.operation.type == 'aten::slice':
has_slice_node = True
for pred in node.predecessors:
if pred.operation.type not in ['aten::slice', 'prim::Constant']:
head_slice_nodes.append(node)
break
if has_slice_node:
assert head_slice_nodes
for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
if len(head_node.incoming_edges) == 4:
# when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
slot += 5
node = head_node
while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
suc_node = node.successors[0]
assert len(suc_node.incoming_edges) == 5
for edge in suc_node.incoming_edges:
if edge.tail_slot == 0:
edge.remove()
else:
edge.tail = new_slice_node
edge.tail_slot = slot + edge.tail_slot - 1
slot += 4
ir_graph.hidden_nodes.remove(node)
node = suc_node
for edge in node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node)
def refine_graph(ir_graph):
"""
Do the following process to simplify graph:
1. remove unconnected constant node
2. remove unconnected getattr node
"""
# some constant is not used, for example, function name as prim::Constant
remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
merge_aten_slices(ir_graph)
def _handle_layerchoice(module):
global modules_arg
m_attrs = {}
candidates = module.candidate_ops
choices = []
for cand in candidates:
assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
assert isinstance(modules_arg[id(cand)], dict)
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
choices.append({'type': cand_type, 'parameters': modules_arg[id(cand)]})
m_attrs[f'choices'] = choices
m_attrs['label'] = module.label
return m_attrs
def _handle_inputchoice(module):
m_attrs = {}
m_attrs['n_chosen'] = module.n_chosen
m_attrs['reduction'] = module.reduction
m_attrs['label'] = module.label
return m_attrs
def convert_module(script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
Returns
-------
Graph
the built graph ir from module, ```None``` means do not further parse the module
dict
the input arguments of this module
"""
global global_graph_id
global modules_arg
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
original_type_name = script_module.original_name
if original_type_name == OpTypeName.LayerChoice:
m_attrs = _handle_layerchoice(module)
return None, m_attrs
if original_type_name == OpTypeName.InputChoice:
m_attrs = _handle_inputchoice(module)
return None, m_attrs
if original_type_name == OpTypeName.Placeholder:
m_attrs = modules_arg[id(module)]
return None, m_attrs
if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
assert id(module) in modules_arg, f'{original_type_name} arguments are not recorded'
m_attrs = modules_arg[id(module)]
return None, m_attrs
# handle TorchScript graph
sm_graph = script_module.graph
global_graph_id += 1
ir_graph = Graph(model=ir_model, graph_id=global_graph_id, name=module_name, _internal=True)
# handle graph nodes
node_index = handle_graph_nodes(script_module, sm_graph, module,
module_name, ir_model, ir_graph)
# handle graph outputs
for _output in sm_graph.outputs():
ir_graph._add_output(_convert_name(_output.debugName()))
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
src_node_idx = None
else:
src_node_idx = predecessor_node_outputs.index(_output)
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
tail=(ir_graph.output_node, None))
refine_graph(ir_graph)
ir_graph._register()
if id(module) not in modules_arg:
raise RuntimeError(f'{original_type_name} arguments are not recorded, \
you might have forgotten to decorate this class with @register_module()')
# TODO: if we parse this module, it means we will create a graph (module class)
# for this module. Then it is not necessary to record this module's arguments
# return ir_graph, modules_arg[id(module)].
# That is, we can refactor this part, to allow users to annotate which module
# should not be parsed further.
return ir_graph, {}
def convert_to_graph(script_module, module, recorded_modules_arg):
"""
Convert module to our graph ir, i.e., build a ```Model``` type
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module obtained with torch.jit.script
module : nn.Module
the targeted module instance
recorded_modules_arg : dict
the recorded args of each module in the module
Returns
Model
the constructed IR model
"""
global modules_arg
modules_arg = recorded_modules_arg
model = Model(_internal=True)
module_name = '_model'
convert_module(script_module, module, module_name, model)
return model
from enum import Enum
MODULE_EXCEPT_LIST = ['Sequential']
class OpTypeName(str, Enum):
"""
op type to its type name str
"""
Attr = 'Attr'
Constant = 'Constant'
ListConstruct = 'ListConstruct'
LayerChoice = 'LayerChoice'
InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
'aten::relu': 'Relu',
'aten::add': 'Add',
'aten::__getitem__': 'getitem',
'aten::append': 'Append',
'aten::len': 'Len',
'aten::slice': 'Slice',
'aten::cat': 'Cat',
'aten::size': 'Size',
'aten::view': 'View',
'aten::eq': 'Eq',
'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
BasicOpsTF = {}
def build_full_name(prefix, name, seq=None):
if isinstance(name, list):
name = '__'.join(name)
if seq is None:
return '{}__{}'.format(prefix, name)
else:
return '{}__{}{}'.format(prefix, name, str(seq))
def _convert_name(name: str) -> str:
"""
Convert the names using separator '.' to valid variable name in code
"""
return name.replace('.', '__')
import graphviz
def convert_to_visualize(graph_ir, vgraph):
for name, graph in graph_ir.items():
if name == '_training_config':
continue
with vgraph.subgraph(name='cluster'+name) as subgraph:
subgraph.attr(color='blue')
cell_node = {}
ioput = {'_inputs': '{}-{}'.format(name, '_'.join(graph['inputs'])),
'_outputs': '{}-{}'.format(name, '_'.join(graph['outputs']))}
subgraph.node(ioput['_inputs'])
subgraph.node(ioput['_outputs'])
for node_name, node_value in graph['nodes'].items():
value = node_value['operation']
if value['type'] == '_cell':
cell_input_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['inputs']))
cell_output_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['outputs']))
cell_node[node_name] = (cell_input_name, cell_output_name)
print('cell: ', node_name, cell_input_name, cell_output_name)
else:
subgraph.node(node_name)
for edge in graph['edges']:
src = edge['head'][0]
if src == '_inputs':
src = ioput['_inputs']
elif src in cell_node:
src = cell_node[src][1]
dst = edge['tail'][0]
if dst == '_outputs':
dst = ioput['_outputs']
elif dst in cell_node:
dst = cell_node[dst][0]
subgraph.edge(src, dst)
def visualize_model(graph_ir):
vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
convert_to_visualize(graph_ir, vgraph)
vgraph.render()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment