Unverified Commit 468917ca authored by QuanluZhang, committed by GitHub

Merge pull request #3155 from microsoft/dev-retiarii

[Do NOT Squash] Merge retiarii dev branch to master
parents f8424a9f d5a551c8
@@ -98,3 +98,6 @@ venv.bak/
# VSCode
.vscode
.vs
.history
generated/
test/ut/retiarii/_debug_graph_data.json
@@ -11,9 +11,9 @@ import torch.nn as nn
import datasets
from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from utils import accuracy
logger = logging.getLogger('nni')
if __name__ == "__main__":
@@ -25,6 +25,7 @@ if __name__ == "__main__":
parser.add_argument("--channels", default=16, type=int)
parser.add_argument("--unrolled", default=False, action="store_true")
parser.add_argument("--visualization", default=False, action="store_true")
parser.add_argument("--v1", default=False, action="store_true")
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
@@ -35,17 +36,35 @@ if __name__ == "__main__":
optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)
trainer = DartsTrainer(model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
optimizer=optim,
num_epochs=args.epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
batch_size=args.batch_size,
log_frequency=args.log_frequency,
unrolled=args.unrolled,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
if args.visualization:
trainer.enable_visualization()
trainer.train()
if args.v1:
from nni.algorithms.nas.pytorch.darts import DartsTrainer
trainer = DartsTrainer(model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
optimizer=optim,
num_epochs=args.epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
batch_size=args.batch_size,
log_frequency=args.log_frequency,
unrolled=args.unrolled,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
if args.visualization:
trainer.enable_visualization()
trainer.train()
else:
from nni.retiarii.trainer.pytorch import DartsTrainer
trainer = DartsTrainer(
model=model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
optimizer=optim,
num_epochs=args.epochs,
dataset=dataset_train,
batch_size=args.batch_size,
log_frequency=args.log_frequency,
unrolled=args.unrolled
)
trainer.fit()
print('Final architecture:', trainer.export())
@@ -48,9 +48,15 @@ class Cell(nn.Module):
], key=cell_name + "_op")
def forward(self, prev_layers):
chosen_input, chosen_mask = self.input_choice(prev_layers)
cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask
from nni.retiarii.trainer.pytorch.random import PathSamplingInputChoice
out = self.input_choice(prev_layers)
if isinstance(self.input_choice, PathSamplingInputChoice):
# Retiarii pattern
return out, self.input_choice.mask
else:
chosen_input, chosen_mask = out
cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask
class Node(mutables.MutableScope):
@@ -71,7 +77,7 @@ class Calibration(nn.Module):
self.process = None
if in_channels != out_channels:
self.process = StdConv(in_channels, out_channels)
def forward(self, x):
if self.process is None:
return x
@@ -83,7 +89,7 @@ class ReductionLayer(nn.Module):
super().__init__()
self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False)
self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False)
def forward(self, pprev, prev):
return self.reduce0(pprev), self.reduce1(prev)
@@ -109,7 +115,7 @@ class ENASLayer(nn.Module):
nn.init.kaiming_normal_(self.final_conv_w)
def forward(self, pprev, prev):
pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)
pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)
prev_nodes_out = [pprev_, prev_]
nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
......
@@ -26,17 +26,22 @@ if __name__ == "__main__":
parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
parser.add_argument("--visualization", default=False, action="store_true")
parser.add_argument("--v1", default=False, action="store_true")
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
mutator = None
ctrl_kwargs = {}
if args.search_for == "macro":
model = GeneralNetwork()
num_epochs = args.epochs or 310
mutator = None
elif args.search_for == "micro":
model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True)
model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=False)
num_epochs = args.epochs or 150
mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
if args.v1:
mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
else:
ctrl_kwargs = {"tanh_constant": 1.1}
else:
raise AssertionError
@@ -44,18 +49,32 @@ if __name__ == "__main__":
optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)
trainer = enas.EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optimizer,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
batch_size=args.batch_size,
num_epochs=num_epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
log_frequency=args.log_frequency,
mutator=mutator)
if args.visualization:
trainer.enable_visualization()
trainer.train()
if args.v1:
trainer = enas.EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optimizer,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
batch_size=args.batch_size,
num_epochs=num_epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
log_frequency=args.log_frequency,
mutator=mutator)
if args.visualization:
trainer.enable_visualization()
trainer.train()
else:
from nni.retiarii.trainer.pytorch.enas import EnasTrainer
trainer = EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optimizer,
batch_size=args.batch_size,
num_epochs=num_epochs,
dataset=dataset_train,
log_frequency=args.log_frequency,
ctrl_kwargs=ctrl_kwargs)
trainer.fit()
@@ -6,7 +6,7 @@ import torchvision
import torchvision.transforms as transforms
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.darts import DartsTrainer
from nni.algorithms.nas.pytorch.darts import DartsTrainer
class Net(nn.Module):
......
import logging
import os
import sys
import logging
from argparse import ArgumentParser
import torch
import datasets
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from torchvision import transforms
from putils import get_parameters
import datasets
from model import SearchMobileNet
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from putils import LabelSmoothingLoss, accuracy, get_parameters
from retrain import Retrain
logger = logging.getLogger('nni_proxylessnas')
@@ -30,7 +33,7 @@ if __name__ == "__main__":
parser.add_argument("--resize_scale", default=0.08, type=float)
parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
# configurations for training mode
parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain'])
parser.add_argument("--train_mode", default='search', type=str, choices=['search_v1', 'search', 'retrain'])
# configurations for search
parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
@@ -80,6 +83,26 @@ if __name__ == "__main__":
optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5)
if args.train_mode == 'search':
from nni.retiarii.trainer.pytorch import ProxylessTrainer
from torchvision.datasets import ImageNet
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = ImageNet(args.data_path, transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
trainer = ProxylessTrainer(model,
loss=LabelSmoothingLoss(),
dataset=dataset,
optimizer=optimizer,
metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
num_epochs=120,
log_frequency=10)
trainer.fit()
print('Final architecture:', trainer.export())
elif args.train_mode == 'search_v1':
# this is architecture search
logger.info('Creating ProxylessNasTrainer...')
trainer = ProxylessNasTrainer(model,
......
@@ -58,11 +58,9 @@ class SearchMobileNet(nn.Module):
# if it is not the first one
op_candidates += [ops.OPS['Zero'](input_channel, width, stride)]
conv_op = nas.mutables.LayerChoice(op_candidates,
return_mask=True,
key="s{}_c{}".format(stage_cnt, i))
else:
conv_op = nas.mutables.LayerChoice(op_candidates,
return_mask=True,
key="s{}_c{}".format(stage_cnt, i))
# shortcut
if stride == 1 and input_channel == width:
......
@@ -39,19 +39,13 @@ class MobileInvertedResidualBlock(nn.Module):
self.op_candidates_list = op_candidates_list
def forward(self, x):
out, idx = self.mobile_inverted_conv(x)
# TODO: unify idx format
if not isinstance(idx, int):
idx = (idx == 1).nonzero()
if self.op_candidates_list[idx].is_zero_layer():
res = x
elif self.shortcut is None:
res = out
else:
conv_x = out
skip_x = self.shortcut(x)
res = skip_x + conv_x
return res
out = self.mobile_inverted_conv(x)
if torch.sum(torch.abs(out)).item() == 0 and x.size() == out.size():
# is zero layer
return x
if self.shortcut is None:
return out
return out + self.shortcut(x)
class ShuffleLayer(nn.Module):
......
import torch
import torch.nn as nn
def get_parameters(model, keys=None, mode='include'):
if keys is None:
for name, param in model.named_parameters():
@@ -36,6 +38,7 @@ def get_same_padding(kernel_size):
assert kernel_size % 2 > 0, 'kernel size should be odd number'
return kernel_size // 2
def build_activation(act_func, inplace=True):
if act_func == 'relu':
return nn.ReLU(inplace=inplace)
@@ -65,3 +68,40 @@ def make_divisible(v, divisor, min_val=None):
if new_v < 0.9 * v:
new_v += divisor
return new_v
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: the slice of a transposed tensor is non-contiguous
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
class LabelSmoothingLoss(nn.Module):
def __init__(self, smoothing=0.1, dim=-1):
super(LabelSmoothingLoss, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.dim = dim
def forward(self, pred, target):
pred = pred.log_softmax(dim=self.dim)
num_classes = pred.size(self.dim)
with torch.no_grad():
true_dist = torch.zeros_like(pred)
true_dist.fill_(self.smoothing / (num_classes - 1))
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
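# As a quick sanity check, a minimal usage sketch of the two helpers above
# (shapes and class count are illustrative, not part of this PR):
import torch

criterion = LabelSmoothingLoss(smoothing=0.1)
logits = torch.randn(4, 10)               # batch of 4 samples, 10 classes
targets = torch.randint(0, 10, (4,))
loss = criterion(logits, targets)         # scalar smoothed cross-entropy
top1 = accuracy(logits, targets)['acc1']  # precision@1 in [0, 1]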
@@ -109,7 +109,13 @@ class MutableScope(Mutable):
def __init__(self, key):
super().__init__(key=key)
def _check_built(self):
return True # bypass the test because it's deprecated
def __call__(self, *args, **kwargs):
if not hasattr(self, 'mutator'):
return super().__call__(*args, **kwargs)
warnings.warn("`MutableScope` is deprecated in Retiarii.", DeprecationWarning)
try:
self._check_built()
self.mutator.enter_mutable_scope(self)
......
from .operation import Operation
from .graph import *
from .execution import *
from .mutator import *
from .utils import register_module
\ No newline at end of file
from .pytorch import model_to_pytorch_script
import logging
from typing import List, Set, Tuple
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
_logger = logging.getLogger(__name__)
def model_to_pytorch_script(model: Model, placement=None) -> str:
graphs = []
total_pkgs = set()
for name, cell in model.graphs.items():
import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
graphs.append(graph_code)
total_pkgs.update(import_pkgs)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
def _sorted_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
_logger.info('sorted_incoming_edges: %s', str(edges))
if not edges:
return []
_logger.info('tail_slots: %s', str([edge.tail_slot for edge in edges]))
if all(edge.tail_slot is None for edge in edges):
return edges
if all(isinstance(edge.tail_slot, int) for edge in edges):
edges = sorted(edges, key=(lambda edge: edge.tail_slot))
if [edge.tail_slot for edge in edges] == list(range(len(edges))):
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> List[str]:
edges = _sorted_incoming_edges(node)
inputs = []
for edge in edges:
if edge.head.name == '_inputs':
assert isinstance(edge.head_slot, int)
if edge.head.operation.io_names is not None:
# when input has names, e.g., forward(self, tensor1, tensor2, another_one)
inputs.append(edge.head.operation.io_names[edge.head_slot])
else:
# when input has no name, e.g., forward(*_inputs)
inputs.append('_inputs[{}]'.format(edge.head_slot))
else:
if edge.head_slot is None:
# when the input comes from a single-output operator
inputs.append('{}'.format(edge.head.name))
else:
# when the input comes from a multi-output operator: needs to know which one it comes from
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
return inputs
def _remove_prefix(names, graph_name):
"""
Variable names (full namespaces) are too long;
shorten them by removing the prefix ```graph_name```
"""
if isinstance(names, list):
converted_names = []
for name in names:
if name.startswith(graph_name):
converted_names.append(name[len(graph_name):])
else:
converted_names.append(name)
return converted_names
else:
return names[len(graph_name):] if names.startswith(graph_name) else names
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[Set[str], str]:
nodes = graph.topo_sort()
# handle module node and function node differently
# only need to generate code for module here
import_pkgs = set()
node_codes = []
for node in nodes:
if node.operation:
pkg_name = node.operation.get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
node_codes.append(f"{node_code}.to('{placement[node].device}')")
else:
node_codes.append(node_code)
if graph.input_node.operation.io_names is None:
input_code = '*_inputs'
else:
for name in graph.input_node.operation.io_names:
assert not name.startswith(graph_name)
input_code = ', '.join(graph.input_node.operation.io_names)
edge_codes = []
sorted_nodes = graph.topo_sort()
for node in sorted_nodes:
if node.operation:
inputs = _format_inputs(node)
inputs = _remove_prefix(inputs, graph_name)
node_name = _remove_prefix(node.name, graph_name)
edge_codes.append(node.operation.to_forward_code(node_name, node_name, inputs))
output_names = _format_inputs(graph.output_node)
output_names = _remove_prefix(output_names, graph_name)
if not output_names:
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
output_code = ', '.join(output_names)
linebreak = '\n '
return import_pkgs, _PyTorchModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else graph_name),
inputs=input_code,
outputs=output_code,
nodes=linebreak.join(node_codes),
edges=linebreak.join(edge_codes)
)
# TODO: handle imports
_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
{}
{}
'''
_PyTorchModelTemplate = '''
class {graph_name}(nn.Module):
def __init__(self):
super().__init__()
{nodes}
def forward(self, {inputs}):
{edges}
return {outputs}
'''
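# For reference, a hand-written sketch of the kind of module these templates
# produce for a trivial single-node graph (names and the node are illustrative,
# not actual generated output of this PR):
_example_generated_code = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn

class Graph(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3)

    def forward(self, image):
        conv = self.conv(image)
        return conv
'''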
# pylint: skip-file
"""
FIXME
This file is inherited from the last version.
I expect it to work with a few modifications to incorporate the latest API, but it hasn't
been tested and I'm not sure.
"""
from typing import List

from ..graph_v2 import IllegalGraphError, Cell, Edge, Graph, Node
from ..operations_tf import Operation
from ..type_utils import *
def graph_to_tensorflow_script(graph: Graph) -> str:
graphs = [graph_to_tensorflow_model(name, cell) for name, cell in graph.cell_templates.items()]
return _TensorFlowScriptTemplate.format('\n\n'.join(graphs)).strip()
def _sort_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
if not edges:
return []
if all(edge.tail_idx is None for edge in edges):
return edges
if all(isinstance(edge.tail_idx, int) for edge in edges):
edges = sorted(edges, key=(lambda edge: edge.tail_idx))
if [edge.tail_idx for edge in edges] == list(range(len(edges))):
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> str:
edges = _sort_incoming_edges(node)
inputs = []
for edge in edges:
if edge.head.name == '_inputs':
assert isinstance(edge.head_idx, int)
if node.graph.input_names is not None:
inputs.append(node.graph.input_names[edge.head_idx])
else:
inputs.append('_inputs[{}]'.format(edge.head_idx))
else:
if edge.head_idx is None:
inputs.append('{}'.format(edge.head.name))
else:
inputs.append('{}[{}]'.format(edge.head.name, edge.head_idx))
return ', '.join(inputs)
def graph_to_tensorflow_model(graph_name: str, graph: Graph) -> str:
nodes = graph.topo_sort()
# handle module node and function node differently
# only need to generate code for module here
node_codes = []
for node in nodes:
if isinstance(node, Cell):
node_codes.append('self.{} = {}()'.format(node.name, node.template_name))
else:
node_codes.append('self.{} = {}'.format(node.name, cast(Operation, node.operation).to_tensorflow_init()))
edge_codes = []
for node in nodes:
inputs = _format_inputs(node)
edge_codes.append('{} = self.{}({})'.format(node.name, node.name, inputs))
output_code = _format_inputs(graph.output_node)
if not output_code:
output_code = 'None'
if graph.input_names is None:
input_code = '*_inputs'
else:
input_code = ', '.join(graph.input_names)
linebreak = '\n '
return _TensorFlowModelTemplate.format(
graph_name=('Graph' if graph_name == '_graph' else graph_name),
inputs=input_code,
outputs=output_code,
nodes=linebreak.join(node_codes),
edges=linebreak.join(edge_codes)
)
_TensorFlowScriptTemplate = '''
import tensorflow as tf
import tensorflow.keras as K
import sdk.custom_ops_tf as CUSTOM
{}
'''
_TensorFlowModelTemplate = '''
class {graph_name}(K.Model):
def __init__(self):
super().__init__()
{nodes}
def call(self, {inputs}):
{edges}
return {outputs}
'''
\ No newline at end of file
# PyTorch Graph Converter
## Namespace for PyTorch Graph
We need a concrete rule for specifying the nodes of a graph with a namespace.
Each node has a name, either specified or generated. Nodes in the same hierarchy cannot share a name.
* The names of module nodes natively follow this rule, because we use the variable names of instantiated modules, as the PyTorch graph does.
* For nodes created in the `forward` function, we use a global sequence number.
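* For example (illustrative), a submodule reached as `cells.0.conv` from the root module would end up with a full name like `_model__cells__0__conv`: `.` separators are rewritten by `_convert_name`, and name parts are joined with `__` by `build_full_name`.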
### Namespace for mutated (new) nodes
TBD
## Graph Simplification
TBD
## Node Types
We define concrete type string for each node type.
## Module's Input Arguments
We use a wrapper to obtain the input arguments of modules. Users need to use our wrapped `nn` and wrapped `Module`.
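
A minimal sketch of the recording idea, with hypothetical names (the actual wrapper shipped in this PR is `register_module`, and the recorded dict is consumed as `recorded_modules_arg` by `convert_to_graph`):

```python
import functools

def record_init_args(cls):
    """Class decorator: remember the keyword arguments each instance was built with."""
    orig_init = cls.__init__

    @functools.wraps(orig_init)
    def init_and_record(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        self._recorded_args = dict(kwargs)  # later collected, keyed by id(module)

    cls.__init__ = init_and_record
    return cls
```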
## Control Flow
### for loop
Currently, we only support for loops based on `ModuleList` (or `ModuleDict`), which are automatically unfolded by TorchScript. That is to say, we do not support loops in TorchScript itself for now.
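
For example (sketch, not from this PR), a loop of this form is supported because TorchScript unrolls it over the `ModuleList`:

```python
import torch.nn as nn

class Stack(nn.Module):
    def __init__(self, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:  # unfolded by TorchScript at scripting time
            x = layer(x)
        return x
```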
### if/else
For now, we only handle the case where the condition is a constant or an attribute. In that case, only one branch is kept when generating the graph, as in the sketch below.
\ No newline at end of file
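
For example (sketch, not from this PR), a condition read from a constant attribute is resolved at conversion time, so only the taken branch is emitted into the generated graph:

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, use_pool=True):
        super().__init__()
        self.use_pool = use_pool  # constant attribute: the condition is decidable offline
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        if self.use_pool:  # only this branch is kept in the generated graph
            x = self.pool(x)
        return x
```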
from .graph_gen import convert_to_graph
import logging
import re
import torch
from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell
from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
from .utils import _convert_name, build_full_name
_logger = logging.getLogger(__name__)
global_seq = 0
global_graph_id = 0
modules_arg = None
def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
"""
Parameters
----------
ir_graph : Graph
node : torch._C.Node
graph_inputs : List[torch._C.Value]
a list of a script graph's inputs
node_index : Dict
new_node : Node
newly created ir node corresponding to `node`
output_remap : Dict
ignore_first : bool
if it is true, skip the first input
"""
is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
new_node_input_idx = 0
for _input in node.inputs():
if ignore_first:
ignore_first = False
continue
# handle source node
if _input in graph_inputs:
idx = graph_inputs.index(_input)
src_node = ir_graph.input_node
src_node_idx = idx
elif _input in output_remap:
assert output_remap[_input].kind() == 'aten::append'
predecessor_node = output_remap[_input]
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
src_node_idx = None
src_node = node_index[predecessor_node]
assert isinstance(src_node, Node)
else:
predecessor_node = _input.node()
assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
# find out the index of _input in the outputs of predecessor_node
predecessor_outputs = [_output for _output in predecessor_node.outputs()]
if len(predecessor_outputs) == 1:
idx = None
else:
idx = predecessor_outputs.index(_input)
ir_predecessor_node = node_index[predecessor_node]
src_node_idx = idx
assert isinstance(ir_predecessor_node, Node)
src_node = ir_predecessor_node
# handle destination node
dst_node = new_node
if is_single_input:
dst_node_idx = None
else:
dst_node_idx = new_node_input_idx
# create edge
ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))
new_node_input_idx += 1
def create_prim_constant_node(ir_graph, node, module_name):
global global_seq
attrs = {}
if node.outputsAt(0).toIValue() is not None:
attrs = {'value': node.outputsAt(0).toIValue()}
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, global_seq),
node.kind(), attrs)
return new_node
def handle_prim_attr_node(node):
assert node.hasAttribute('name')
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
return node.kind(), attrs
def _remove_mangle(module_type_str):
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
def remove_unconnected_nodes(ir_graph, targeted_type=None):
"""
Parameters
----------
ir_graph : Graph
our ir graph representation
targeted_type : str
nodes with ```targeted_type``` will be removed from the graph if their fanout is 0.
```None``` means removing all nodes whose fanout is 0.
"""
# build index of outputs of Node(s)
node_fanout = set()
for edge in ir_graph.edges:
if edge.head.id not in node_fanout:
node_fanout.add(edge.head.id)
to_removes = []
for hidden_node in ir_graph.hidden_nodes:
if hidden_node.id not in node_fanout:
assert isinstance(hidden_node, Node)
if targeted_type is None:
to_removes.append(hidden_node)
elif hidden_node.operation.type == targeted_type:
to_removes.append(hidden_node)
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
"""
Convert torch script node to our node ir, and build our graph ir
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the torch script of ```module```
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
```module```'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ```module```
Returns
-------
dict
the mapping from graph node to our graph ir node
"""
# handle inputs
graph_inputs = []
for _input in sm_graph.inputs():
if _input.debugName() == 'self':
assert _input.unique() == 0
continue
graph_inputs.append(_input)
# TODO: add scope name
ir_graph._add_input(_convert_name(_input.debugName()))
node_index = {} # graph node to graph ir node
# some nodes have no output but modify a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
# %out.1 is updated, and %17 is None
# we add output to this type of node and connect it to the following node which uses %out.1
# key: tensor (%out.1), value: node (this node)
output_remap = {}
def handle_if_condition(cond_tensor):
"""
To calculate the condition, we only deal with the following op types, tracing back through:
`prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`.
The expression is generated using recursive calls.
NOTE: dynamic graphs are not supported
"""
def _generate_expr(tensor):
if tensor.node().kind() == 'prim::GetAttr':
return f'({getattr(module, tensor.node().s("name"))})'
elif tensor.node().kind() == 'aten::__getitem__':
t = _generate_expr(tensor.node().inputsAt(0))
idx = _generate_expr(tensor.node().inputsAt(1))
return f'({t}[{idx}])'
elif tensor.node().kind() == 'prim::Constant':
return f'{tensor.toIValue()}'
elif tensor.node().kind() == 'aten::eq':
left = _generate_expr(tensor.node().inputsAt(0))
right = _generate_expr(tensor.node().inputsAt(1))
return f'({left} == {right})'
else:
raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition')
expr = _generate_expr(cond_tensor)
return eval(expr)
def handle_if_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
# for now, only deal with the case where the input of prim::If is a constant or an attribute
# constant expressions will be supported in the future
inputs = [i for i in node.inputs()]
assert len(inputs) == 1
cond = handle_if_condition(inputs[0])
chosen_block = 0 if cond else 1
blocks = [block for block in node.blocks()]
assert len(blocks) == 2
last_block_node = None
for node in blocks[chosen_block].nodes():
last_block_node = handle_single_node(node)
return last_block_node
def handle_single_node(node):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
global global_seq
if node.kind() == 'prim::CallMethod':
# get and handle the first input, which should be an nn.Module
assert node.hasAttribute('name')
if node.s('name') == 'forward':
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str = _remove_mangle(node.inputsAt(0).type().str())
submodule = node.inputsAt(0).node()
assert submodule.kind() == 'prim::GetAttr'
assert submodule.hasAttribute('name')
submodule_name = submodule.s('name')
if submodule.inputsAt(0).debugName() == 'self':
# module is usually instantiated in __init__.
# when calling a module in forward,
# prim::GetAttr is used to obtain the module in torch script.
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
submodule_name, script_module._modules.keys())
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name],
submodule_obj,
submodule_full_name, ir_model)
else:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
if submodule.inputsAt(0).type().name() == 'ModuleList':
# handle ModuleList
predecessor = submodule.inputsAt(0).node()
assert predecessor.kind() == 'prim::GetAttr'
assert predecessor.hasAttribute('name')
assert predecessor.inputsAt(0).debugName() == 'self'
predecessor_name = predecessor.s('name')
# FIXME: exchange
submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
predecessor_obj = getattr(module, predecessor_name)
submodule_obj = getattr(predecessor_obj, submodule_name)
subgraph, sub_m_attrs = convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
submodule_obj, submodule_full_name, ir_model)
else:
raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
# TODO: match subgraph with maintained graphs
# build cell
if subgraph is None:
# if we do not parse this module's graph, we create Node for this module
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
elif isinstance(submodule_obj, (LayerChoice, InputChoice)):
subcell.update_label(sub_m_attrs['label'])
else:
# Graph already created, create Cell for it
new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
subcell = ir_graph.add_node(submodule_full_name, new_cell)
node_index[node] = subcell
# connect the cell into graph
_add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
else:
raise RuntimeError('unsupported CallMethod {}'.format(node.s('name')))
elif node.kind() == 'prim::CallFunction':
func_type_str = _remove_mangle(node.inputsAt(0).type().str())
func = node.inputsAt(0).node()
assert func.kind() == 'prim::Constant'
assert func.hasAttribute('name')
func_name = func.s('name')
# create node for func
global_seq += 1
func_node = ir_graph.add_node(build_full_name(module_name, func_name, global_seq),
'{}.{}'.format(func_type_str, func_name))
node_index[node] = func_node
_add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
elif node.kind() == 'prim::Constant':
new_node = create_prim_constant_node(ir_graph, node, module_name)
node_index[node] = new_node
elif node.kind() == 'prim::ListConstruct':
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, global_seq), node.kind())
node_index[node] = new_node
_add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
elif node.kind() == 'aten::append':
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
output_remap[node.inputsAt(0)] = node
elif node.kind().startswith('aten::'):
# handle aten::XXX
global_seq += 1
aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
node_index[node] = aten_node
_add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
elif node.kind() == 'prim::GetAttr':
node_type, attrs = handle_prim_attr_node(node)
global_seq += 1
new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
node_type, attrs)
node_index[node] = new_node
elif node.kind() == 'prim::If':
last_block_node = handle_if_node(node)
# last_block_node is None means no node in the branch block
node_index[node] = last_block_node
elif node.kind() == 'prim::Loop':
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise RuntimeError('Loop has not been supported yet!')
else:
raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
return node_index[node]
for node in sm_graph.nodes():
handle_single_node(node)
return node_index
def merge_aten_slices(ir_graph):
"""
If there are aten::slice nodes, merge consecutive ones together.
```x[:, :, 1:, 1:]``` in Python code is converted into 4 nodes in TorchScript;
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
"""
head_slice_nodes = []
has_slice_node = False
for node in ir_graph.hidden_nodes:
if node.operation.type == 'aten::slice':
has_slice_node = True
for pred in node.predecessors:
if pred.operation.type not in ['aten::slice', 'prim::Constant']:
head_slice_nodes.append(node)
break
if has_slice_node:
assert head_slice_nodes
for head_node in head_slice_nodes:
slot = 0
new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
if len(head_node.incoming_edges) == 4:
# when the slice is over a one-dimensional list, there are only 4 inputs, so merging is not needed
break
assert len(head_node.incoming_edges) == 5
for edge in head_node.incoming_edges:
edge.tail = new_slice_node
slot += 5
node = head_node
while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
suc_node = node.successors[0]
assert len(suc_node.incoming_edges) == 5
for edge in suc_node.incoming_edges:
if edge.tail_slot == 0:
edge.remove()
else:
edge.tail = new_slice_node
edge.tail_slot = slot + edge.tail_slot - 1
slot += 4
ir_graph.hidden_nodes.remove(node)
node = suc_node
for edge in node.outgoing_edges:
edge.head = new_slice_node
ir_graph.hidden_nodes.remove(node)
def refine_graph(ir_graph):
"""
Do the following process to simplify graph:
1. remove unconnected constant node
2. remove unconnected getattr node
"""
# some constant is not used, for example, function name as prim::Constant
remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
merge_aten_slices(ir_graph)
def _handle_layerchoice(module):
global modules_arg
m_attrs = {}
candidates = module.candidate_ops
choices = []
for cand in candidates:
assert id(cand) in modules_arg, 'id does not exist: {}'.format(id(cand))
assert isinstance(modules_arg[id(cand)], dict)
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
choices.append({'type': cand_type, 'parameters': modules_arg[id(cand)]})
m_attrs['choices'] = choices
m_attrs['label'] = module.label
return m_attrs
def _handle_inputchoice(module):
m_attrs = {}
m_attrs['n_chosen'] = module.n_chosen
m_attrs['reduction'] = module.reduction
m_attrs['label'] = module.label
return m_attrs
def convert_module(script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
Returns
-------
Graph
the graph ir built from the module; ```None``` means the module is not parsed further
dict
the input arguments of this module
"""
global global_graph_id
global modules_arg
# NOTE: nested LayerChoice is not supported yet, i.e., a candidate module
# that itself contains LayerChoice, InputChoice, or ValueChoice
original_type_name = script_module.original_name
if original_type_name == OpTypeName.LayerChoice:
m_attrs = _handle_layerchoice(module)
return None, m_attrs
if original_type_name == OpTypeName.InputChoice:
m_attrs = _handle_inputchoice(module)
return None, m_attrs
if original_type_name == OpTypeName.Placeholder:
m_attrs = modules_arg[id(module)]
return None, m_attrs
if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST:
# this is a basic module from pytorch, no need to parse its graph
assert id(module) in modules_arg, f'{original_type_name} arguments are not recorded'
m_attrs = modules_arg[id(module)]
return None, m_attrs
# handle TorchScript graph
sm_graph = script_module.graph
global_graph_id += 1
ir_graph = Graph(model=ir_model, graph_id=global_graph_id, name=module_name, _internal=True)
# handle graph nodes
node_index = handle_graph_nodes(script_module, sm_graph, module,
module_name, ir_model, ir_graph)
# handle graph outputs
for _output in sm_graph.outputs():
ir_graph._add_output(_convert_name(_output.debugName()))
predecessor_node_outputs = [o for o in _output.node().outputs()]
if len(predecessor_node_outputs) == 1:
src_node_idx = None
else:
src_node_idx = predecessor_node_outputs.index(_output)
ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
tail=(ir_graph.output_node, None))
refine_graph(ir_graph)
ir_graph._register()
if id(module) not in modules_arg:
raise RuntimeError(f'{original_type_name} arguments are not recorded, \
you might have forgotten to decorate this class with @register_module()')
# TODO: if we parse this module, it means we will create a graph (module class)
# for this module. Then it is not necessary to record this module's arguments
# return ir_graph, modules_arg[id(module)].
# That is, we can refactor this part, to allow users to annotate which module
# should not be parsed further.
return ir_graph, {}
def convert_to_graph(script_module, module, recorded_modules_arg):
"""
Convert module to our graph ir, i.e., build a ```Model``` type
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module obtained with torch.jit.script
module : nn.Module
the targeted module instance
recorded_modules_arg : dict
the recorded arguments of each module in the model
Returns
-------
Model
the constructed IR model
"""
global modules_arg
modules_arg = recorded_modules_arg
model = Model(_internal=True)
module_name = '_model'
convert_module(script_module, module, module_name, model)
return model
from enum import Enum
MODULE_EXCEPT_LIST = ['Sequential']
class OpTypeName(str, Enum):
"""
op type to its type name str
"""
Attr = 'Attr'
Constant = 'Constant'
ListConstruct = 'ListConstruct'
LayerChoice = 'LayerChoice'
InputChoice = 'InputChoice'
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
'aten::relu': 'Relu',
'aten::add': 'Add',
'aten::__getitem__': 'getitem',
'aten::append': 'Append',
'aten::len': 'Len',
'aten::slice': 'Slice',
'aten::cat': 'Cat',
'aten::size': 'Size',
'aten::view': 'View',
'aten::eq': 'Eq',
'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
BasicOpsTF = {}
def build_full_name(prefix, name, seq=None):
if isinstance(name, list):
name = '__'.join(name)
if seq is None:
return '{}__{}'.format(prefix, name)
else:
return '{}__{}{}'.format(prefix, name, str(seq))
def _convert_name(name: str) -> str:
"""
Convert a name that uses the separator '.' into a valid variable name in code
"""
return name.replace('.', '__')
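# For example: _convert_name('cells.0.conv') -> 'cells__0__conv',
# and build_full_name('_model', 'conv', seq=2) -> '_model__conv2'.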
import graphviz
def convert_to_visualize(graph_ir, vgraph):
for name, graph in graph_ir.items():
if name == '_training_config':
continue
with vgraph.subgraph(name='cluster'+name) as subgraph:
subgraph.attr(color='blue')
cell_node = {}
ioput = {'_inputs': '{}-{}'.format(name, '_'.join(graph['inputs'])),
'_outputs': '{}-{}'.format(name, '_'.join(graph['outputs']))}
subgraph.node(ioput['_inputs'])
subgraph.node(ioput['_outputs'])
for node_name, node_value in graph['nodes'].items():
value = node_value['operation']
if value['type'] == '_cell':
cell_input_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['inputs']))
cell_output_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['outputs']))
cell_node[node_name] = (cell_input_name, cell_output_name)
print('cell: ', node_name, cell_input_name, cell_output_name)
else:
subgraph.node(node_name)
for edge in graph['edges']:
src = edge['head'][0]
if src == '_inputs':
src = ioput['_inputs']
elif src in cell_node:
src = cell_node[src][1]
dst = edge['tail'][0]
if dst == '_outputs':
dst = ioput['_outputs']
elif dst in cell_node:
dst = cell_node[dst][0]
subgraph.edge(src, dst)
def visualize_model(graph_ir):
vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
convert_to_visualize(graph_ir, vgraph)
vgraph.render()
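# A minimal usage sketch (assumes the model IR has been dumped to JSON in the
# dict form consumed above, with an optional '_training_config' entry; needs
# the graphviz Python package and a local Graphviz install):
#
#     import json
#     with open('graph_ir.json') as f:  # hypothetical dump of the model IR
#         graph_ir = json.load(f)
#     visualize_model(graph_ir)         # writes and renders 'vgraph.jpg'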