Unverified Commit 468917ca authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

Merge pull request #3155 from microsoft/dev-retiarii

[Do NOT Squash] Merge retiarii dev branch to master
parents f8424a9f d5a551c8
...@@ -98,3 +98,6 @@ venv.bak/ ...@@ -98,3 +98,6 @@ venv.bak/
# VSCode # VSCode
.vscode .vscode
.vs .vs
.history
generated/
test/ut/retiarii/_debug_graph_data.json
...@@ -11,9 +11,9 @@ import torch.nn as nn ...@@ -11,9 +11,9 @@ import torch.nn as nn
import datasets import datasets
from model import CNN from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.algorithms.nas.pytorch.darts import DartsTrainer
from utils import accuracy from utils import accuracy
logger = logging.getLogger('nni') logger = logging.getLogger('nni')
if __name__ == "__main__": if __name__ == "__main__":
...@@ -25,6 +25,7 @@ if __name__ == "__main__": ...@@ -25,6 +25,7 @@ if __name__ == "__main__":
parser.add_argument("--channels", default=16, type=int) parser.add_argument("--channels", default=16, type=int)
parser.add_argument("--unrolled", default=False, action="store_true") parser.add_argument("--unrolled", default=False, action="store_true")
parser.add_argument("--visualization", default=False, action="store_true") parser.add_argument("--visualization", default=False, action="store_true")
parser.add_argument("--v1", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10") dataset_train, dataset_valid = datasets.get_dataset("cifar10")
...@@ -35,6 +36,8 @@ if __name__ == "__main__": ...@@ -35,6 +36,8 @@ if __name__ == "__main__":
optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4) optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)
if args.v1:
from nni.algorithms.nas.pytorch.darts import DartsTrainer
trainer = DartsTrainer(model, trainer = DartsTrainer(model,
loss=criterion, loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)), metrics=lambda output, target: accuracy(output, target, topk=(1,)),
...@@ -48,4 +51,20 @@ if __name__ == "__main__": ...@@ -48,4 +51,20 @@ if __name__ == "__main__":
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")]) callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
if args.visualization: if args.visualization:
trainer.enable_visualization() trainer.enable_visualization()
trainer.train() trainer.train()
else:
from nni.retiarii.trainer.pytorch import DartsTrainer
trainer = DartsTrainer(
model=model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
optimizer=optim,
num_epochs=args.epochs,
dataset=dataset_train,
batch_size=args.batch_size,
log_frequency=args.log_frequency,
unrolled=args.unrolled
)
trainer.fit()
print('Final architecture:', trainer.export())
...@@ -48,7 +48,13 @@ class Cell(nn.Module): ...@@ -48,7 +48,13 @@ class Cell(nn.Module):
], key=cell_name + "_op") ], key=cell_name + "_op")
def forward(self, prev_layers): def forward(self, prev_layers):
chosen_input, chosen_mask = self.input_choice(prev_layers) from nni.retiarii.trainer.pytorch.random import PathSamplingInputChoice
out = self.input_choice(prev_layers)
if isinstance(self.input_choice, PathSamplingInputChoice):
# Retiarii pattern
return out, self.input_choice.mask
else:
chosen_input, chosen_mask = out
cell_out = self.op_choice(chosen_input) cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask return cell_out, chosen_mask
......
...@@ -26,17 +26,22 @@ if __name__ == "__main__": ...@@ -26,17 +26,22 @@ if __name__ == "__main__":
parser.add_argument("--search-for", choices=["macro", "micro"], default="macro") parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)") parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
parser.add_argument("--visualization", default=False, action="store_true") parser.add_argument("--visualization", default=False, action="store_true")
parser.add_argument("--v1", default=False, action="store_true")
args = parser.parse_args() args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10") dataset_train, dataset_valid = datasets.get_dataset("cifar10")
mutator = None
ctrl_kwargs = {}
if args.search_for == "macro": if args.search_for == "macro":
model = GeneralNetwork() model = GeneralNetwork()
num_epochs = args.epochs or 310 num_epochs = args.epochs or 310
mutator = None
elif args.search_for == "micro": elif args.search_for == "micro":
model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=False)
num_epochs = args.epochs or 150 num_epochs = args.epochs or 150
if args.v1:
mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True) mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
else:
ctrl_kwargs = {"tanh_constant": 1.1}
else: else:
raise AssertionError raise AssertionError
...@@ -44,6 +49,7 @@ if __name__ == "__main__": ...@@ -44,6 +49,7 @@ if __name__ == "__main__":
optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)
if args.v1:
trainer = enas.EnasTrainer(model, trainer = enas.EnasTrainer(model,
loss=criterion, loss=criterion,
metrics=accuracy, metrics=accuracy,
...@@ -59,3 +65,16 @@ if __name__ == "__main__": ...@@ -59,3 +65,16 @@ if __name__ == "__main__":
if args.visualization: if args.visualization:
trainer.enable_visualization() trainer.enable_visualization()
trainer.train() trainer.train()
else:
from nni.retiarii.trainer.pytorch.enas import EnasTrainer
trainer = EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optimizer,
batch_size=args.batch_size,
num_epochs=num_epochs,
dataset=dataset_train,
log_frequency=args.log_frequency,
ctrl_kwargs=ctrl_kwargs)
trainer.fit()
...@@ -6,7 +6,7 @@ import torchvision ...@@ -6,7 +6,7 @@ import torchvision
import torchvision.transforms as transforms import torchvision.transforms as transforms
from nni.nas.pytorch.mutables import LayerChoice, InputChoice from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.darts import DartsTrainer from nni.algorithms.nas.pytorch.darts import DartsTrainer
class Net(nn.Module): class Net(nn.Module):
......
import logging
import os import os
import sys import sys
import logging
from argparse import ArgumentParser from argparse import ArgumentParser
import torch import torch
import datasets from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from torchvision import transforms
from putils import get_parameters import datasets
from model import SearchMobileNet from model import SearchMobileNet
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from putils import LabelSmoothingLoss, accuracy, get_parameters
from retrain import Retrain from retrain import Retrain
logger = logging.getLogger('nni_proxylessnas') logger = logging.getLogger('nni_proxylessnas')
...@@ -30,7 +33,7 @@ if __name__ == "__main__": ...@@ -30,7 +33,7 @@ if __name__ == "__main__":
parser.add_argument("--resize_scale", default=0.08, type=float) parser.add_argument("--resize_scale", default=0.08, type=float)
parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None']) parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
# configurations for training mode # configurations for training mode
parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain']) parser.add_argument("--train_mode", default='search', type=str, choices=['search_v1', 'search', 'retrain'])
# configurations for search # configurations for search
parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str) parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
parser.add_argument("--arch_path", default='./arch_path.pt', type=str) parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
...@@ -80,6 +83,26 @@ if __name__ == "__main__": ...@@ -80,6 +83,26 @@ if __name__ == "__main__":
optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5) optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5)
if args.train_mode == 'search': if args.train_mode == 'search':
from nni.retiarii.trainer.pytorch import ProxylessTrainer
from torchvision.datasets import ImageNet
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
dataset = ImageNet(args.data_path, transform=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
trainer = ProxylessTrainer(model,
loss=LabelSmoothingLoss(),
dataset=dataset,
optimizer=optimizer,
metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
num_epochs=120,
log_frequency=10)
trainer.fit()
print('Final architecture:', trainer.export())
elif args.train_mode == 'search_v1':
# this is architecture search # this is architecture search
logger.info('Creating ProxylessNasTrainer...') logger.info('Creating ProxylessNasTrainer...')
trainer = ProxylessNasTrainer(model, trainer = ProxylessNasTrainer(model,
......
...@@ -58,11 +58,9 @@ class SearchMobileNet(nn.Module): ...@@ -58,11 +58,9 @@ class SearchMobileNet(nn.Module):
# if it is not the first one # if it is not the first one
op_candidates += [ops.OPS['Zero'](input_channel, width, stride)] op_candidates += [ops.OPS['Zero'](input_channel, width, stride)]
conv_op = nas.mutables.LayerChoice(op_candidates, conv_op = nas.mutables.LayerChoice(op_candidates,
return_mask=True,
key="s{}_c{}".format(stage_cnt, i)) key="s{}_c{}".format(stage_cnt, i))
else: else:
conv_op = nas.mutables.LayerChoice(op_candidates, conv_op = nas.mutables.LayerChoice(op_candidates,
return_mask=True,
key="s{}_c{}".format(stage_cnt, i)) key="s{}_c{}".format(stage_cnt, i))
# shortcut # shortcut
if stride == 1 and input_channel == width: if stride == 1 and input_channel == width:
......
...@@ -39,19 +39,13 @@ class MobileInvertedResidualBlock(nn.Module): ...@@ -39,19 +39,13 @@ class MobileInvertedResidualBlock(nn.Module):
self.op_candidates_list = op_candidates_list self.op_candidates_list = op_candidates_list
def forward(self, x): def forward(self, x):
out, idx = self.mobile_inverted_conv(x) out = self.mobile_inverted_conv(x)
# TODO: unify idx format if torch.sum(torch.abs(out)).item() == 0 and x.size() == out.size():
if not isinstance(idx, int): # is zero layer
idx = (idx == 1).nonzero() return x
if self.op_candidates_list[idx].is_zero_layer(): if self.shortcut is None:
res = x return out
elif self.shortcut is None: return out + self.shortcut(x)
res = out
else:
conv_x = out
skip_x = self.shortcut(x)
res = skip_x + conv_x
return res
class ShuffleLayer(nn.Module): class ShuffleLayer(nn.Module):
......
import torch
import torch.nn as nn import torch.nn as nn
def get_parameters(model, keys=None, mode='include'): def get_parameters(model, keys=None, mode='include'):
if keys is None: if keys is None:
for name, param in model.named_parameters(): for name, param in model.named_parameters():
...@@ -36,6 +38,7 @@ def get_same_padding(kernel_size): ...@@ -36,6 +38,7 @@ def get_same_padding(kernel_size):
assert kernel_size % 2 > 0, 'kernel size should be odd number' assert kernel_size % 2 > 0, 'kernel size should be odd number'
return kernel_size // 2 return kernel_size // 2
def build_activation(act_func, inplace=True): def build_activation(act_func, inplace=True):
if act_func == 'relu': if act_func == 'relu':
return nn.ReLU(inplace=inplace) return nn.ReLU(inplace=inplace)
...@@ -65,3 +68,40 @@ def make_divisible(v, divisor, min_val=None): ...@@ -65,3 +68,40 @@ def make_divisible(v, divisor, min_val=None):
if new_v < 0.9 * v: if new_v < 0.9 * v:
new_v += divisor new_v += divisor
return new_v return new_v
def accuracy(output, target, topk=(1,)):
    """Compute precision@k for each k in ``topk``.

    Args:
        output: scores/logits of shape (batch, num_classes).
        target: class indices of shape (batch,), or a one-hot matrix of
            shape (batch, num_classes).
        topk: iterable of k values to evaluate.

    Returns:
        dict mapping ``"acc{k}"`` to the fraction of samples whose target
        class is among the top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case: recover class indices
    if target.ndimension() > 1:
        target = target.max(1)[1]
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = dict()
    for k in topk:
        # reshape(-1) instead of view(-1): the slice is not guaranteed to be
        # contiguous, and view() raises on non-contiguous tensors in newer
        # PyTorch versions.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    The target distribution puts ``1 - smoothing`` mass on the true class
    and spreads ``smoothing`` uniformly over the remaining classes.
    """

    def __init__(self, smoothing=0.1, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.dim = dim

    def forward(self, pred, target):
        log_probs = pred.log_softmax(dim=self.dim)
        num_classes = log_probs.size(self.dim)
        with torch.no_grad():
            # smoothed one-hot target distribution (no gradient needed)
            smoothed = torch.full_like(log_probs, self.smoothing / (num_classes - 1))
            smoothed.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-smoothed * log_probs, dim=self.dim))
...@@ -109,7 +109,13 @@ class MutableScope(Mutable): ...@@ -109,7 +109,13 @@ class MutableScope(Mutable):
def __init__(self, key): def __init__(self, key):
super().__init__(key=key) super().__init__(key=key)
def _check_built(self):
return True # bypass the test because it's deprecated
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
if not hasattr(self, 'mutator'):
return super().__call__(*args, **kwargs)
warnings.warn("`MutableScope` is deprecated in Retiarii.", DeprecationWarning)
try: try:
self._check_built() self._check_built()
self.mutator.enter_mutable_scope(self) self.mutator.enter_mutable_scope(self)
......
from .operation import Operation
from .graph import *
from .execution import *
from .mutator import *
from .utils import register_module
\ No newline at end of file
from .pytorch import model_to_pytorch_script
import logging
from typing import List
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
_logger = logging.getLogger(__name__)
def model_to_pytorch_script(model: Model, placement=None) -> str:
    """Render ``model`` as a runnable PyTorch script.

    Every graph in the model becomes an ``nn.Module`` class; the union of
    the packages those classes need is emitted as imports at the top.
    """
    graph_codes = []
    required_pkgs = set()
    for graph_name, cell in model.graphs.items():
        pkgs, code = graph_to_pytorch_model(graph_name, cell, placement=placement)
        required_pkgs.update(pkgs)
        graph_codes.append(code)
    import_lines = '\n'.join('import {}'.format(pkg) for pkg in required_pkgs)
    return _PyTorchScriptTemplate.format(import_lines, '\n\n'.join(graph_codes)).strip()
def _sorted_incoming_edges(node: Node) -> List[Edge]:
    """Return the incoming edges of ``node`` ordered by ``tail_slot``.

    Accepted layouts:
    * no incoming edges -> empty list;
    * every ``tail_slot`` is ``None`` -> edges returned in graph order;
    * ``tail_slot``s are exactly ``0..len-1`` -> edges sorted by slot.

    Raises:
        IllegalGraphError: if the slots match none of the layouts above.
    """
    edges = [edge for edge in node.graph.edges if edge.tail is node]
    # Debug level: this runs for every node during code generation, so
    # info-level logging of each edge list would flood the log.
    _logger.debug('sorted_incoming_edges: %s', edges)
    if not edges:
        return []
    if all(edge.tail_slot is None for edge in edges):
        return edges
    if all(isinstance(edge.tail_slot, int) for edge in edges):
        edges.sort(key=lambda edge: edge.tail_slot)
        if [edge.tail_slot for edge in edges] == list(range(len(edges))):
            return edges
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> List[str]:
    """Render the source expression of every input feeding ``node``.

    Graph inputs are referenced either by their declared names or by
    ``_inputs[i]``; outputs of other nodes are referenced by node name,
    indexed when the producer has multiple outputs.
    """
    rendered = []
    for edge in _sorted_incoming_edges(node):
        head = edge.head
        if head.name == '_inputs':
            assert isinstance(edge.head_slot, int)
            io_names = head.operation.io_names
            if io_names is None:
                # unnamed graph inputs, e.g., forward(*_inputs)
                rendered.append('_inputs[{}]'.format(edge.head_slot))
            else:
                # named graph inputs, e.g., forward(self, tensor1, tensor2)
                rendered.append(io_names[edge.head_slot])
        elif edge.head_slot is None:
            # the input comes from a single-output operator
            rendered.append('{}'.format(head.name))
        else:
            # multi-output operator: index the specific output
            rendered.append('{}[{}]'.format(head.name, edge.head_slot))
    return rendered
def _remove_prefix(names, graph_name):
"""
variables name (full name space) is too long,
shorten the name by removing the prefix ```graph_name```
"""
if isinstance(names, list):
converted_names = []
for name in names:
if name.startswith(graph_name):
converted_names.append(name[len(graph_name):])
else:
converted_names.append(name)
return converted_names
else:
return names[len(graph_name):] if names.startswith(graph_name) else names
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None):
    """Generate the ``nn.Module`` class code for a single graph.

    Args:
        graph_name: namespace of the graph; also used as the class name
            (``'_graph'`` maps to ``'Graph'``).
        graph: the graph IR to render.
        placement: optional mapping of nodes to devices; placed modules get
            a ``.to('<device>')`` suffix in their init code.

    Returns:
        A pair ``(import_pkgs, class_code)`` where ``import_pkgs`` is the
        set of package names the generated code needs.  (The previous
        ``-> str`` annotation was wrong: this has always returned a tuple.)

    Raises:
        RuntimeError: if the graph's output node has no inputs, i.e. the
            generated ``forward`` would have nothing to return.
    """
    nodes = graph.topo_sort()
    # handle module node and function node differently:
    # only module nodes need instantiation code in __init__
    import_pkgs = set()
    node_codes = []
    for node in nodes:
        if node.operation:
            pkg_name = node.operation.get_import_pkg()
            if pkg_name is not None:
                import_pkgs.add(pkg_name)
            node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
            if node_code is not None:
                if placement and node in placement and len(node_code) > 0:
                    node_codes.append(f"{node_code}.to('{placement[node].device}')")
                else:
                    node_codes.append(node_code)
    if graph.input_node.operation.io_names is None:
        input_code = '*_inputs'
    else:
        for name in graph.input_node.operation.io_names:
            assert not name.startswith(graph_name)
        input_code = ', '.join(graph.input_node.operation.io_names)
    edge_codes = []
    # Reuse the topological order computed above instead of calling
    # graph.topo_sort() a second time.
    for node in nodes:
        if node.operation:
            inputs = _remove_prefix(_format_inputs(node), graph_name)
            node_name = _remove_prefix(node.name, graph_name)
            edge_codes.append(node.operation.to_forward_code(node_name, node_name, inputs))
    output_names = _remove_prefix(_format_inputs(graph.output_node), graph_name)
    if not output_names:
        raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
    output_code = ', '.join(output_names)
    # NOTE(review): this literal should match the body indentation of
    # _PyTorchModelTemplate — the scraped source may have mangled it; verify.
    linebreak = '\n '
    return import_pkgs, _PyTorchModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )
# TODO: handle imports
_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
{}
{}
'''
_PyTorchModelTemplate = '''
class {graph_name}(nn.Module):
def __init__(self):
super().__init__()
{nodes}
def forward(self, {inputs}):
{edges}
return {outputs}
'''
# pylint: skip-file
"""
FIXME
This file is inherited from last version.
I expect it can work with a few modifications to incorporate with the latest API, but it hasn't
been tested and I'm not sure.
"""
from ..graph_v2 import IllegalGraphError, Cell, Edge, Graph, Node
from ..operations_tf import Operation
from ..type_utils import *
def graph_to_tensorflow_script(graph: Graph) -> str:
    """Render ``graph`` and all of its cell templates as a TensorFlow script."""
    model_codes = [
        graph_to_tensorflow_model(name, cell)
        for name, cell in graph.cell_templates.items()
    ]
    return _TensorFlowScriptTemplate.format('\n\n'.join(model_codes)).strip()
def _sort_incoming_edges(node: Node) -> List[Edge]:
    """Collect the edges entering ``node``, ordered by ``tail_idx``.

    Mirrors the PyTorch converter: either every ``tail_idx`` is ``None``
    (graph order kept) or the indices are exactly ``0..len-1`` (sorted);
    anything else is an illegal graph.
    """
    incoming = [e for e in node.graph.edges if e.tail is node]
    if not incoming:
        return []
    if all(e.tail_idx is None for e in incoming):
        return incoming
    if all(isinstance(e.tail_idx, int) for e in incoming):
        incoming.sort(key=lambda e: e.tail_idx)
        if [e.tail_idx for e in incoming] == list(range(len(incoming))):
            return incoming
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> str:
    """Build the comma-joined argument string for calling ``node`` in ``call``."""
    parts = []
    for edge in _sort_incoming_edges(node):
        source = edge.head
        if source.name == '_inputs':
            assert isinstance(edge.head_idx, int)
            if node.graph.input_names is not None:
                # named graph input
                parts.append(node.graph.input_names[edge.head_idx])
            else:
                # positional graph input
                parts.append('_inputs[{}]'.format(edge.head_idx))
        elif edge.head_idx is None:
            # single-output producer: reference by name
            parts.append('{}'.format(source.name))
        else:
            # multi-output producer: index the specific output
            parts.append('{}[{}]'.format(source.name, edge.head_idx))
    return ', '.join(parts)
def graph_to_tensorflow_model(graph_name: str, graph: Graph) -> str:
    """Generate the Keras model class code for a single graph.

    Cell nodes are emitted as instantiations of their template class;
    plain operation nodes go through ``Operation.to_tensorflow_init``.
    """
    nodes = graph.topo_sort()
    # handle module node and function node differently
    # only need to generate code for module here
    node_codes = []
    for node in nodes:
        if isinstance(node, Cell):
            node_codes.append('self.{} = {}()'.format(node.name, node.template_name))
        else:
            node_codes.append('self.{} = {}'.format(node.name, cast(Operation, node.operation).to_tensorflow_init()))
    # one call statement per node, in topological order
    edge_codes = []
    for node in nodes:
        inputs = _format_inputs(node)
        edge_codes.append('{} = self.{}({})'.format(node.name, node.name, inputs))
    output_code = _format_inputs(graph.output_node)
    if not output_code:
        # a graph without outputs still needs a syntactically valid return
        output_code = 'None'
    if graph.input_names is None:
        input_code = '*_inputs'
    else:
        input_code = ', '.join(graph.input_names)
    # NOTE(review): this literal should match the body indentation of
    # _TensorFlowModelTemplate — the scraped source may have mangled it; verify.
    linebreak = '\n '
    return _TensorFlowModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )
_TensorFlowScriptTemplate = '''
import tensorflow as tf
import tensorflow.keras as K
import sdk.custom_ops_tf as CUSTOM
{}
'''
_TensorFlowModelTemplate = '''
class {graph_name}(K.Model):
def __init__(self):
super().__init__()
{nodes}
def call(self, {inputs}):
{edges}
return {outputs}
'''
\ No newline at end of file
# PyTorch Graph Converter
## Namespace for PyTorch Graph
We should have a concrete rule for specifying nodes in graph with namespace.
Each node has a name, either specified or generated. The nodes in the same hierarchy cannot have the same name.
* The names of module nodes naturally follow this rule, because we use the variable names of instantiated modules, just as the PyTorch graph does.
* For the nodes created in `forward` function, we use a global sequence number.
### Namespace for mutated (new) nodes
TBD
## Graph Simplification
TBD
## Node Types
We define concrete type string for each node type.
## Module's Input Arguments
We use wrapper to obtain the input arguments of modules. Users need to use our wrapped "nn" and wrapped "Module".
## Control Flow
### for loop
Currently, we only support for loops based on `ModuleList` (or `ModuleDict`), which TorchScript unfolds automatically. In other words, we do not yet support general loops in TorchScript.
### if/else
For now, we only handle the case where the condition is a constant or an attribute. In that case, only one branch is kept when generating the graph.
\ No newline at end of file
from .graph_gen import convert_to_graph
This diff is collapsed.
from enum import Enum
MODULE_EXCEPT_LIST = ['Sequential']
class OpTypeName(str, Enum):
    """
    op type to its type name str

    String-valued enum of the special node types recognized by the graph
    converter. Inheriting from ``str`` lets members compare equal to the
    plain strings stored in the graph IR.
    """
    Attr = 'Attr'
    Constant = 'Constant'
    ListConstruct = 'ListConstruct'
    LayerChoice = 'LayerChoice'
    InputChoice = 'InputChoice'
    ValueChoice = 'ValueChoice'
    Placeholder = 'Placeholder'
    MergedSlice = 'MergedSlice'
MergedSlice = 'MergedSlice'
# deal with aten op
# Maps TorchScript ``aten::`` operator names to the basic-op type names
# used in the graph IR.
BasicOpsPT = {
    'aten::mean': 'Mean',
    'aten::relu': 'Relu',
    'aten::add': 'Add',
    'aten::__getitem__': 'getitem',
    'aten::append': 'Append',
    'aten::len': 'Len',
    'aten::slice': 'Slice',
    'aten::cat': 'Cat',
    'aten::size': 'Size',
    'aten::view': 'View',
    'aten::eq': 'Eq',
    'aten::add_': 'Add_'  # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}

# TensorFlow counterpart of ``BasicOpsPT``; no mappings defined yet.
BasicOpsTF = {}
def build_full_name(prefix, name, seq=None):
    """Compose a fully-qualified node name.

    ``name`` may be a list of parts (joined with '__'); when ``seq`` is
    given it is appended directly to the name.
    """
    if isinstance(name, list):
        name = '__'.join(name)
    base = '{}__{}'.format(prefix, name)
    return base if seq is None else '{}{}'.format(base, seq)
def _convert_name(name: str) -> str:
"""
Convert the names using separator '.' to valid variable name in code
"""
return name.replace('.', '__')
import graphviz
def convert_to_visualize(graph_ir, vgraph):
    """Populate the graphviz digraph ``vgraph`` from ``graph_ir``.

    Each graph in the IR becomes a blue cluster subgraph. ``_cell`` nodes
    do not get a node of their own: edges touching them are rerouted to the
    input/output nodes of the referenced cell's cluster instead.
    """
    for name, graph in graph_ir.items():
        if name == '_training_config':
            # not a graph — skip
            continue
        with vgraph.subgraph(name='cluster' + name) as subgraph:
            subgraph.attr(color='blue')
            cell_node = {}
            # synthetic nodes standing in for the graph's inputs/outputs
            ioput = {'_inputs': '{}-{}'.format(name, '_'.join(graph['inputs'])),
                     '_outputs': '{}-{}'.format(name, '_'.join(graph['outputs']))}
            subgraph.node(ioput['_inputs'])
            subgraph.node(ioput['_outputs'])
            for node_name, node_value in graph['nodes'].items():
                value = node_value['operation']
                if value['type'] == '_cell':
                    # remember the referenced cell's IO node names so the
                    # edge loop below can reroute through them
                    # (removed a leftover debug print of these names)
                    cell_input_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['inputs']))
                    cell_output_name = '{}-{}'.format(value['cell_name'], '_'.join(graph_ir[value['cell_name']]['outputs']))
                    cell_node[node_name] = (cell_input_name, cell_output_name)
                else:
                    subgraph.node(node_name)
            for edge in graph['edges']:
                src = edge['head'][0]
                if src == '_inputs':
                    src = ioput['_inputs']
                elif src in cell_node:
                    # outputs of a cell come from its cluster's output node
                    src = cell_node[src][1]
                dst = edge['tail'][0]
                if dst == '_outputs':
                    dst = ioput['_outputs']
                elif dst in cell_node:
                    # inputs of a cell go to its cluster's input node
                    dst = cell_node[dst][0]
                subgraph.edge(src, dst)
def visualize_model(graph_ir):
    """Render ``graph_ir`` to a ``vgraph.jpg`` file via graphviz."""
    digraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
    convert_to_visualize(graph_ir, digraph)
    digraph.render()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment