Unverified Commit 5c861676 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Graph torch14 refactor (#2384)

parent ac238f01
# Azure hosted agents specification:
# https://docs.microsoft.com/en-us/azure/devops/pipelines/agents/hosted?view=azure-devops
jobs: jobs:
- job: 'basic_test_pr_ubuntu' - job: 'ubuntu_1804_python36'
pool: pool:
vmImage: 'Ubuntu 16.04' vmImage: 'Ubuntu 18.04'
strategy:
matrix:
Python36:
PYTHON_VERSION: '3.6'
steps: steps:
- script: | - script: |
...@@ -26,9 +25,8 @@ jobs: ...@@ -26,9 +25,8 @@ jobs:
yarn eslint yarn eslint
displayName: 'Run eslint' displayName: 'Run eslint'
- script: | - script: |
python3 -m pip install torch==1.2.0 --user python3 -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install torchvision==0.4.0 --user python3 -m pip install tensorflow==1.15.2 --user
python3 -m pip install tensorflow==1.13.1 --user
python3 -m pip install keras==2.1.6 --user python3 -m pip install keras==2.1.6 --user
python3 -m pip install gym onnx --user python3 -m pip install gym onnx --user
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 --user python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 --user
...@@ -63,13 +61,41 @@ jobs: ...@@ -63,13 +61,41 @@ jobs:
sphinx-build -M html . _build -W sphinx-build -M html . _build -W
displayName: 'Sphinx Documentation Build check' displayName: 'Sphinx Documentation Build check'
- job: 'basic_test_pr_macOS' - job: 'ubuntu_1604_python35_legacy_torch'
pool:
vmImage: 'Ubuntu 16.04'
steps:
- script: |
python3 -m pip install --upgrade pip setuptools --user
python3 -m pip install coverage --user
echo "##vso[task.setvariable variable=PATH]${HOME}/.local/bin:${PATH}"
displayName: 'Install python tools'
- script: |
source install.sh
displayName: 'Install nni toolkit via source code'
- script: |
python3 -m pip install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorflow==1.15.2 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install gym onnx --user
sudo apt-get install swig -y
nnictl package install --name=SMAC
nnictl package install --name=BOHB
displayName: 'Install dependencies'
- script: |
cd test
source scripts/unittest.sh
displayName: 'Unit test'
- script: |
cd test
python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml
displayName: 'Simple test'
- job: 'macos_1015_python37'
pool: pool:
vmImage: 'macOS-10.15' vmImage: 'macOS-10.15'
strategy:
matrix:
Python36:
PYTHON_VERSION: '3.6'
steps: steps:
- script: python3 -m pip install --upgrade pip setuptools - script: python3 -m pip install --upgrade pip setuptools
...@@ -79,9 +105,9 @@ jobs: ...@@ -79,9 +105,9 @@ jobs:
echo "##vso[task.setvariable variable=PATH]${HOME}/Library/Python/3.7/bin:${PATH}" echo "##vso[task.setvariable variable=PATH]${HOME}/Library/Python/3.7/bin:${PATH}"
displayName: 'Install nni toolkit via source code' displayName: 'Install nni toolkit via source code'
- script: | - script: |
python3 -m pip install torch==1.2.0 --user # pytorch Mac binary does not support CUDA, default is cpu version
python3 -m pip install torchvision==0.4.0 --user python3 -m pip install torchvision==0.6.0 torch==1.5.0 --user
python3 -m pip install tensorflow==1.13.1 --user python3 -m pip install tensorflow==1.15.2 --user
brew install swig@3 brew install swig@3
rm /usr/local/bin/swig rm /usr/local/bin/swig
ln -s /usr/local/opt/swig\@3/bin/swig /usr/local/bin/swig ln -s /usr/local/opt/swig\@3/bin/swig /usr/local/bin/swig
...@@ -96,13 +122,9 @@ jobs: ...@@ -96,13 +122,9 @@ jobs:
python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml
displayName: 'Simple test' displayName: 'Simple test'
- job: 'basic_test_pr_Windows' - job: 'win2016_python37'
pool: pool:
vmImage: 'vs2017-win2016' vmImage: 'vs2017-win2016'
strategy:
matrix:
Python36:
PYTHON_VERSION: '3.6'
steps: steps:
- script: | - script: |
...@@ -111,9 +133,8 @@ jobs: ...@@ -111,9 +133,8 @@ jobs:
- script: | - script: |
python -m pip install scikit-learn==0.20.0 --user python -m pip install scikit-learn==0.20.0 --user
python -m pip install keras==2.1.6 --user python -m pip install keras==2.1.6 --user
python -m pip install torch===1.2.0 torchvision===0.4.1 -f https://download.pytorch.org/whl/torch_stable.html --user python -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install torchvision --user python -m pip install tensorflow==1.15.2 --user
python -m pip install tensorflow==1.13.1 --user
displayName: 'Install dependencies' displayName: 'Install dependencies'
- script: | - script: |
cd test cd test
......
...@@ -34,7 +34,7 @@ print('elapsed time: ', time.time() - start) ...@@ -34,7 +34,7 @@ print('elapsed time: ', time.time() - start)
``` ```
For complete examples please refer to [the code](https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py) For complete examples please refer to [the code](https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py)
NOTE: The current implementation only works on torch 1.3.1 and torchvision 0.4.2 NOTE: The current implementation supports PyTorch 1.3.1 or newer.
## Limitations ## Limitations
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import queue
import re
from collections import defaultdict
import torch
from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy
CLASSTYPE_KIND = 'ClassType'
GETATTR_KIND = 'prim::GetAttr'
_logger = logging.getLogger(__name__)
def build_module_graph(model, dummy_input):
    """Trace *model* with *dummy_input* and return its ``TorchModuleGraph``."""
    module_graph = TorchModuleGraph(model, dummy_input)
    return module_graph
def build_graph(model, dummy_input, verbose=False):
    """Trace *model* and return its protobuf ``GraphDef`` and step stats.

    Set ``verbose=True`` to also print the raw jit trace graph.
    """
    proto_graph = TorchProtoGraph(model, dummy_input, verbose)
    return proto_graph.graph_def, proto_graph.stepstats
def parse_traced_name(module_name):
    """Strip the ``TracedModule[...]`` wrapper from a traced module name.

    Names produced by ``torch.jit.trace`` look like ``TracedModule[Linear]``;
    return the inner class name (``Linear``). Names without the wrapper are
    returned unchanged.
    """
    prefix, suffix = 'TracedModule[', ']'
    wrapped = module_name.startswith(prefix) and module_name.endswith(suffix)
    return module_name[len(prefix):-len(suffix)] if wrapped else module_name
class TorchGraph:
    """
    This class is to extract pytorch model topology graph by tracing
    """

    def __init__(self, model, dummy_input):
        """
        Parameters
        ----------
        model : pytorch model
            The model user wants to speed up
        dummy_input : pytorch tensor
            The dummy input for ```jit.trace```, users should put it on right device before pass in
        """
        # Compare numeric version components instead of raw strings:
        # lexicographic string comparison misorders versions, e.g.
        # '1.13.0' < '1.3.1' as strings, which would wrongly reject new torch.
        version = tuple(int(v) for v in re.findall(r'\d+', torch.__version__.split('+')[0])[:3])
        assert version >= (1, 3, 1), 'torch >= 1.3.1 is required, found %s' % torch.__version__
        self.bound_model = model
        self._trace(model, dummy_input)

    def _trace(self, model, dummy_input):
        # Trace in eval mode so the graph reflects inference behavior, then
        # inline calls so submodule ops all appear in one flat graph.
        with torch.onnx.set_training(model, False):
            self.trace = torch.jit.trace(model, dummy_input)
            torch._C._jit_pass_inline(self.trace.graph)
class TorchProtoGraph(TorchGraph):
    """
    Generates model graph for pytorch models in protobuf, this implementation is borrowed from pytorch v1.4.0,
    and fixed following issues:
    https://github.com/pytorch/pytorch/issues/33691
    https://github.com/pytorch/pytorch/issues/33670
    """

    def __init__(self, model, dummy_input, verbose=False):
        # TorchGraph.__init__ traces the model and populates self.trace.
        super().__init__(model, dummy_input)

        # Imported lazily so tensorboard is only required when a protobuf
        # graph is actually requested.
        from tensorboard.compat.proto.config_pb2 import RunMetadata
        from tensorboard.compat.proto.graph_pb2 import GraphDef
        from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
        from tensorboard.compat.proto.versions_pb2 import VersionDef

        list_of_nodes = self.parse(self.trace.graph, self.trace, dummy_input)
        if verbose:
            print(self.trace.graph)
        # Minimal RunMetadata/GraphDef pair, enough for TensorBoard rendering.
        self.stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
        self.graph_def = GraphDef(node=list_of_nodes, versions=VersionDef(producer=22))

    def parse(self, graph, trace, args=None, omit_useless_nodes=True):
        """This method parses an optimized PyTorch model graph and produces
        a list of nodes and node stats for eventual conversion to TensorBoard
        protobuf format.

        Args:
            graph (PyTorch module): The model graph to be parsed.
            trace (PyTorch JIT TracedModule): The model trace to be parsed.
            args (tuple): input tensor[s] for the model.
            omit_useless_nodes (boolean): Whether to remove nodes from the graph.
        """
        nodes_py = GraphPy()
        # Graph inputs become NodePyIO 'input' nodes; the ClassType input is
        # the module's `self` and is skipped.
        for node in graph.inputs():
            if omit_useless_nodes:
                if not node.uses():  # number of user of the node (= number of outputs/ fanout)
                    continue
            if node.type().kind() != CLASSTYPE_KIND:
                nodes_py.append(NodePyIO(node, 'input'))

        attr_to_scope = dict()
        # Identify a node by the '%name' token before the first ':' in its
        # string form.
        node_to_name = lambda d: str(d).split(":")[0].strip()
        for node in graph.nodes():
            if node.kind() == GETATTR_KIND:
                # prim::GetAttr chains encode the module hierarchy; rebuild a
                # scope path like '__module.a/__module.a.b' from them.
                attr_name = node.s('name')
                node_name = node_to_name(node)
                parent = node.input().node()
                if parent.kind() == GETATTR_KIND:  # If the parent node is not the top-level "self" node
                    parent_scope = attr_to_scope[node_to_name(parent)]
                    attr_scope = parent_scope.split('/')[-1]
                    attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
                else:
                    attr_to_scope[node_name] = '__module.{}'.format(attr_name)
                # We don't need classtype nodes; scope will provide this information
                if node.output().type().kind() != CLASSTYPE_KIND:
                    node_py = NodePyOP(node)
                    node_py.scopeName = attr_to_scope[node_name]
                    nodes_py.append(node_py)
            else:
                nodes_py.append(NodePyOP(node))

        for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
            node_py = NodePyIO(node, 'output')
            node_py.debugName = "output.{}".format(i + 1)
            node_py.inputs = [node.debugName()]
            nodes_py.append(node_py)

        # Map '__module.x.y' aliases to readable 'ClassName[attr]' segments so
        # scope names render like 'Net/Conv2d[conv1]'.
        alias_to_name = dict()
        base_name = parse_traced_name(trace._name)
        for name, module in trace.named_modules(prefix='__module'):
            mod_name = parse_traced_name(module._name)
            attr_name = name.split('.')[-1]
            alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)

        for node in nodes_py.nodes_op:
            module_aliases = node.scopeName.split('/')[-1].split('.')
            module_name = ''
            for i, alias in enumerate(module_aliases):
                if i == 0:
                    module_name = alias
                    node.scopeName = base_name
                else:
                    module_name += '.' + alias
                node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias)

        nodes_py.populate_namespace_from_OP_to_IO()
        return nodes_py.to_proto()
class NodePyGroup(NodePy):
    """
    A logical graph node assembled from one or more jit trace nodes.

    A single ``torch.nn.Module`` (or a functional call inside ``forward``)
    usually produces several trace nodes; this class groups them so the model
    graph exposes exactly one node per module or function.
    """

    def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=None):
        """
        Parameters:
        -----------
        name: str
            node name, such as `conv1`, `backbone.classifier`
        node_type: str
            `module` or `func`
        op_type: str
            operation type, such as `Conv2d`, `aten::view`
        node_cpps: list of torch._C.Node
            jit trace nodes which are included in this new node
        inputs: list of str
            All the inputs of this node, each element is debugName of one input
        outputs: list of str
            All the outputs of this node, each element is debugName of one output
        """
        super(NodePyGroup, self).__init__(name, [])
        self.node_cpps = node_cpps
        self.name = name
        self.op_type = op_type
        self.type = node_type
        self.nodes = []
        # op-specific extra info (e.g. in/out shapes for aten::view)
        self.auxiliary = None
        self.add_nodes(node_cpps)
        self.inputs = inputs
        self.outputs = outputs

    def add_nodes(self, node_cpps):
        """Wrap each raw trace node in a ``NodePyOP`` and record it."""
        for raw_node in node_cpps:
            wrapped = NodePyOP(raw_node)
            # Name after the '%'-stripped output identifier of its string form.
            wrapped.name = str(raw_node).split(':')[0].strip().replace('%', '')
            self.nodes.append(wrapped)

    def sub_node_names(self):
        """Return the names of all wrapped trace nodes."""
        return [sub.name for sub in self.nodes]

    def __repr__(self):
        return 'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'.format(
            self.name, self.type, self.op_type, self.sub_node_names(),
            self.inputs, self.outputs, self.auxiliary)
class TorchModuleGraph(TorchGraph):
"""
Generates model graph, each node is created from single or multiple jit trace nodes.
"""
def __init__(self, model, dummy_input):
super().__init__(model, dummy_input)
self.global_count = 0
self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
    def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
        """
        For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
        the functions directly called in module ```forward```. For such nodes, some of them are
        trivial op which are label by ```prim::```, some of them are not such ops which is call
        non-prim ops. This function is to merge neighbor prim ops to a non-prim op, to construct
        a node.

        Parameters
        ----------
        node : trace graph node
            The non-prim node to expand
        nodes : list of trace graph node
            All the trace graph nodes within the same scope as the non-prim node
        input_to_node : dict
            key: input name, value: a node that uses this input
        output_to_node : dict
            key: output name, value: a node that generates this output

        Returns
        -------
        node
            the expanded non-prim node
        """
        # TODO: scope name could be empty
        # Name format: <module name>.<op kind>.<sequence number>
        node_name = '.'.join([self._get_module_name(node.scopeName()), node.kind(), str(self.global_count)])
        _logger.debug("expand non-prim node, node name: %s", node_name)
        self.global_count += 1
        op_type = node.kind()
        node_group = [node]
        inputs = list()
        outputs = list()
        # BFS backwards from `node`: absorb prim:: predecessors that live in
        # the same scope into the group; every other value becomes an
        # external input of the group.
        node_queue = queue.Queue()
        node_queue.put(node)
        while not node_queue.empty():
            curr_node = node_queue.get()
            for _input in curr_node.inputs():
                input_name = _input.debugName()
                if input_name in output_to_node and output_to_node[input_name] in nodes:
                    predecessor_node = output_to_node[input_name]
                    if predecessor_node.kind().startswith('prim::'):
                        node_group.append(predecessor_node)
                        node_queue.put(predecessor_node)
                    else:
                        inputs.append(input_name)
                else:
                    inputs.append(input_name)
        # Only the seed node's outputs are exposed; absorbed prim nodes feed
        # it internally.
        for output in node.outputs():
            outputs.append(output.debugName())
        nodepy = NodePyGroup(node_name, 'func', op_type, node_group, inputs=inputs, outputs=outputs)
        return nodepy
    def _build_module_node_group(self, module_name, op_type, node_cpps, input_to_node, output_to_node):
        """
        Merge all trace nodes of one leaf module into a single ``NodePyGroup``.

        Parameters
        ----------
        module_name : str
            dotted module name, e.g. ``backbone.conv2``
        op_type : str
            module class name, e.g. ``Conv2d``
        node_cpps : list of torch._C.Node
            all trace nodes belonging to this module
        input_to_node : dict
            key: input debugName, value: node consuming it
        output_to_node : dict
            key: output debugName, value: node producing it

        Returns
        -------
        NodePyGroup
            the merged module node with its external inputs/outputs
        """
        graph = self.trace.graph
        inputs, outputs = [], []
        for n in node_cpps:
            for i in n.inputs():
                name = i.debugName()
                # A value is an external input if it is a graph input or is
                # produced by a node outside this group.
                # NOTE(review): if `name` is neither in output_to_node nor a
                # graph input, the elif raises KeyError — presumably every
                # value is covered by one of the two cases; verify.
                if not name in output_to_node and i in graph.inputs():
                    inputs.append(name)
                elif output_to_node[name] not in node_cpps:
                    inputs.append(name)
            for o in n.outputs():
                name = o.debugName()
                # Symmetrically, a value is an external output if it is a
                # graph output or is consumed by a node outside this group.
                if not name in input_to_node and o in graph.outputs():
                    outputs.append(name)
                elif input_to_node[name] not in node_cpps:
                    outputs.append(name)
        return NodePyGroup(module_name, 'module', op_type, node_cpps, inputs, outputs)
def _extract_shape_info(self, node):
"""
Extract the shape information of ```aten::view``` node
Parameters
----------
node : trace graph node
It should be ```aten::view``` node
Returns
-------
dict
Include shape of input tensor and shape of output tensor
"""
t_input = None
for _input in node.inputs():
t_input = _input
break
t_output = node.output()
assert isinstance(t_input.type(), torch._C.TensorType)
assert isinstance(t_output.type(), torch._C.TensorType)
in_shape = t_input.type().sizes()
out_shape = t_output.type().sizes()
return {'in_shape': in_shape, 'out_shape': out_shape}
def _extract_leaf_modules(self):
"""
Extract leaf modules from the given graph. Leaf module means it does not have submodules.
To extract leaf modules because only leaf module can be replaced. And shape inference can
be done in leaf module level. Other shape inference is done in lower level i.e.,
operation level.
Returns
-------
list
a list of scope name of all the leaf modules
"""
def is_parent(name1, name2):
"""
check if name1 is parent node of name2, for example:
name1: aa.bb, name2: aa.bb.cc, return True
name1: aa.b, name2: aa.bb, return False
"""
parts1, parts2 = name1.split('.'), name2.split('.')
if len(parts1) >= len(parts2):
return False
for i in range(len(parts1)):
if parts2[i] != parts1[i]:
return False
return True
module_names = sorted([x[0] for x in self.trace.named_modules() if x[0]])
leaf_nodes = []
for i, name in enumerate(module_names):
if i + 1 >= len(module_names) or not is_parent(name, module_names[i + 1]):
leaf_nodes.append(name)
return leaf_nodes
def _get_module_name(self, scope_name):
"""
Retrieve module name from scope name.
Parameters:
-----------
scope_name: str
scope_name of a graph node, for example:
for pytorch 1.3.1: MyModel/BackboneModel[backbone]/Conv2d[conv2]
for pytorch 1.4.0: __module.backbone/__module.backbone.conv2
Returns:
-------
str
module name, such as backbone.conv2
"""
if torch.__version__ >= '1.4.0':
return scope_name.split('/')[-1].replace('__module.', '')
else:
return '.'.join(re.findall(r'\[(.*?)\]', scope_name))
def _build_index(self, nodes_op):
name_to_node = dict()
input_to_node = defaultdict(list)
output_to_node = dict()
for node in nodes_op:
name_to_node[node.name] = node
for _input in node.inputs:
input_to_node[_input].append(node)
for output in node.outputs:
assert not output in output_to_node, \
"One output cannot be generated by multiple nodes"
output_to_node[output] = node
return name_to_node, input_to_node, output_to_node
    def _build_graph(self):
        """
        Build graph using our defined format from jit trace.
        There are basically three steps: first, construct necessary information (data structures),
        second, extract all the modules to convert to node, Third, extract all functions to convert
        to node.

        Returns
        -------
        dict
            use name to index nodes, key: node name, value: node
        dict
            use input (its name) to index nodes,
            key: input, value: list of nodes that take this input
        dict
            use output (its name) to index nodes,
            key: output, value: node that generates this output
        """
        omit_useless_nodes = True
        graph = self.trace.graph
        _logger.debug(graph)
        # build output mapping, from output debugName to its node
        output_to_node = {x.debugName(): n for n in graph.nodes() for x in n.outputs()}
        # build input mapping, from input debugName to its node
        input_to_node = {x.debugName(): n for n in graph.nodes() for x in n.inputs()}
        # build module mapping, from module name to all nodes (as list) under this module scope
        module_to_nodes = defaultdict(list)
        # the mapping of function (non-module in forward) to nodes, key is scope name
        func_to_nodes = defaultdict(list)
        nodes_py = GraphPy()
        # graph inputs become NodePyIO 'input' nodes; the ClassType input is
        # the module's `self` and is skipped
        for node in graph.inputs():
            if omit_useless_nodes:
                if not node.uses():  # number of user of the node (= number of outputs/ fanout)
                    continue
            if node.type().kind() != 'ClassType':
                nodes_py.append(NodePyIO(node, 'input'))
        self.leaf_modules = self._extract_leaf_modules()
        module_to_type = {name: parse_traced_name(module._name) for name, module in self.trace.named_modules()}
        # associate module name with their trace graph nodes
        for node in graph.nodes():
            module_name = self._get_module_name(node.scopeName())
            if module_name in self.leaf_modules:
                module_to_nodes[module_name].append(node)
            else:
                func_to_nodes[node.scopeName()].append(node)
        # build node group for module
        for module_name, node_cpps in module_to_nodes.items():
            node_group = self._build_module_node_group(
                module_name, module_to_type[module_name], node_cpps, input_to_node, output_to_node
            )
            _logger.debug('node_group: %s', node_group)
            nodes_py.nodes_op.append(node_group)
        # each scope_name may have multiple funcs, we split them and create node for each of them
        # build node group for torch.nn.functional
        for _, nodes in func_to_nodes.items():
            # extract non prim:: nodes
            non_prim_nodes = list()
            for node in nodes:
                if not node.kind().startswith('prim::'):
                    non_prim_nodes.append(node)
            # for each non prim node, expand it
            for node in non_prim_nodes:
                node_group = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
                nodes_py.nodes_op.append(node_group)
                # get shape infor for view (aten::view) func
                if node_group.op_type in ['aten::view', 'aten::flatten']:
                    node_group.auxiliary = self._extract_shape_info(node)
        for node in graph.outputs():  # Create sink nodes for output ops
            node_py = NodePyIO(node, 'output')
            nodes_py.append(node_py)
        self.nodes_py = nodes_py
        # build index
        return self._build_index(self.nodes_py.nodes_op)
def find_predecessors(self, module_name):
"""
Find predecessor node of the given node
Parameters
----------
module_name : str
The name of the node
Returns
-------
list
a list of nodes who are the given node's predecessor
"""
predecessors = []
for _input in self.name_to_node[module_name].inputs:
if not _input in self.output_to_node:
_logger.debug("cannot find node with %s as its output", _input)
else:
node_py = self.output_to_node[_input]
predecessors.append(node_py.name)
return predecessors
def find_successors(self, module_name):
"""
Find successor nodes of the given node
Parameters
----------
module_name : str
The name of the node
Returns
-------
list
a list of nodes who are the given node's successor
"""
successors = []
for output in self.name_to_node[module_name].outputs:
assert output in self.input_to_node, "No node with input {}".format(output)
nodes_py = self.input_to_node[output]
for node_py in nodes_py:
successors.append(node_py.name)
return successors
...@@ -12,6 +12,7 @@ replace_module = { ...@@ -12,6 +12,7 @@ replace_module = {
'Conv2d': lambda module, mask: replace_conv2d(module, mask), 'Conv2d': lambda module, mask: replace_conv2d(module, mask),
'MaxPool2d': lambda module, mask: no_replace(module, mask), 'MaxPool2d': lambda module, mask: no_replace(module, mask),
'AvgPool2d': lambda module, mask: no_replace(module, mask), 'AvgPool2d': lambda module, mask: no_replace(module, mask),
'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
'ReLU': lambda module, mask: no_replace(module, mask), 'ReLU': lambda module, mask: no_replace(module, mask),
'Linear': lambda module, mask: replace_linear(module, mask) 'Linear': lambda module, mask: replace_linear(module, mask)
} }
......
...@@ -2,9 +2,8 @@ ...@@ -2,9 +2,8 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import logging import logging
import queue
import re
import torch import torch
from nni._graph_utils import build_module_graph
from .compress_modules import replace_module from .compress_modules import replace_module
from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape from .infer_shape import ModuleMasks, infer_from_mask, infer_from_inshape, infer_from_outshape
...@@ -33,38 +32,6 @@ def get_module_by_name(model, module_name): ...@@ -33,38 +32,6 @@ def get_module_by_name(model, module_name):
leaf_module = getattr(model, name_list[-1]) leaf_module = getattr(model, name_list[-1])
return model, leaf_module return model, leaf_module
class GNode:
"""
It is used to represent a node in model graph, in this graph a module is a node,
a function out of module (in ```forward``` function) could also be a node.
"""
def __init__(self, node_name, node_type, op_type, inputs, outputs, nodes):
"""
Parameters
----------
node_name : str
It is module name if the node is a module, it is ```scope_name.node_kind.seq``` if it is a func
node_type : str
It only has two options: `module` or `func`
op_type : str
The operation type of the module or func
inputs : list of str
All the inputs of this node, each element is debugName of one input
outputs : list of str
All the outputs of this node, each element is debugName of one output
nodes : list of node
All the trace graph nodes included in this module or func
"""
self.name = node_name
self.type = node_type
self.op_type = op_type
self.inputs = inputs
self.outputs = outputs
self.nodes = nodes
# store supplementary information for different op types
# for example, for ```view``` it stores the shape of its input and output
self.auxiliary = None
class ModelSpeedup: class ModelSpeedup:
""" """
This class is to speedup the model with provided weight mask This class is to speedup the model with provided weight mask
...@@ -84,347 +51,9 @@ class ModelSpeedup: ...@@ -84,347 +51,9 @@ class ModelSpeedup:
the device on which masks are placed, same to map_location in ```torch.load``` the device on which masks are placed, same to map_location in ```torch.load```
""" """
self.bound_model = model self.bound_model = model
self.dummy_input = dummy_input
self.masks = torch.load(masks_file, map_location) self.masks = torch.load(masks_file, map_location)
self.is_training = model.training
# to obtain forward graph, model should be in ```eval``` mode
if self.is_training:
model.eval()
self.trace_graph = torch.jit.trace(model, dummy_input)
if self.is_training:
model.train()
self.inferred_masks = dict() # key: module_name, value: ModuleMasks self.inferred_masks = dict() # key: module_name, value: ModuleMasks
self.g_nodes = list() self.torch_graph = build_module_graph(model, dummy_input)
self.global_count = 0
self.name_to_gnode, self.input_to_gnode, self.output_to_gnode = self._build_graph()
def _build_index_for_gnodes(self, g_nodes):
"""
Build indexes for quick search
Parameters
----------
g_nodes : list of GNode
All the g_node in processed model graph
Returns
-------
dict
use name to index g_nodes, key: node name, value: g_node
dict
use input (its name) to index g_nodes,
key: input, value: list of g_nodes that take this input
dict
use output (its name) to index g_nodes,
key: output, value: g_node that generates this output
"""
name_to_gnode = dict()
input_to_gnode = dict()
output_to_gnode = dict()
for node in g_nodes:
name_to_gnode[node.name] = node
for _input in node.inputs:
if _input in input_to_gnode:
input_to_gnode[_input].append(node)
else:
input_to_gnode[_input] = [node]
for output in node.outputs:
assert not output in output_to_gnode, \
"One output cannot be generated by multiple nodes"
output_to_gnode[output] = node
return name_to_gnode, input_to_gnode, output_to_gnode
def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
"""
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
the functions directly called in module ```forward```. For such nodes, some of them are
trivial op which are label by ```prim::```, some of them are not such ops which is call
non-prim ops. This function is to merge neighbor prim ops to a non-prim op, to construct
a GNode.
Parameters
----------
node : trace graph node
The non-prim node to expand
nodes : list of trace graph node
All the trace graph nodes within the same scope as the non-prim node
input_to_node : dict
key: input name, value: a node that uses this input
output_to_node : dict
key: output name, value: a node that generates this output
Returns
-------
GNode
the expanded non-prim node in GNode format
"""
# TODO: scope name could be empty
node_name = '.'.join([node.scopeName(), node.kind(), str(self.global_count)])
_logger.debug("expand non-prim node, node name: %s", node_name)
self.global_count += 1
op_type = node.kind()
node_group = [node]
inputs = list()
outputs = list()
node_queue = queue.Queue()
node_queue.put(node)
while not node_queue.empty():
curr_node = node_queue.get()
for _input in curr_node.inputs():
input_name = _input.debugName()
if input_name in output_to_node and output_to_node[input_name] in nodes:
predecessor_node = output_to_node[input_name]
if predecessor_node.kind().startswith('prim::'):
node_group.append(predecessor_node)
node_queue.put(predecessor_node)
else:
inputs.append(input_name)
else:
inputs.append(input_name)
for output in node.outputs():
outputs.append(output.debugName())
g_node = GNode(node_name, 'func', op_type, inputs, outputs, node_group)
return g_node
def _extract_shape_info(self, node):
"""
Extract the shape information of ```aten::view``` node
Parameters
----------
node : trace graph node
It should be ```aten::view``` node
Returns
-------
dict
Include shape of input tensor and shape of output tensor
"""
t_input = None
for _input in node.inputs():
t_input = _input
break
t_output = node.output()
assert isinstance(t_input.type(), torch._C.TensorType)
assert isinstance(t_output.type(), torch._C.TensorType)
in_shape = t_input.type().sizes()
out_shape = t_output.type().sizes()
return {'in_shape': in_shape, 'out_shape': out_shape}
def _extract_leaf_modules(self, graph):
"""
Extract leaf modules from the given graph. Leaf module means it does not have submodules.
To extract leaf modules because only leaf module can be replaced. And shape inference can
be done in leaf module level. Other shape inference is done in lower level i.e.,
operation level.
Parameters
----------
graph : jit trace graph
the graph generated from jit trace
Returns
-------
list
a list of scope name of all the leaf modules
"""
class SNode:
def __init__(self, name):
self.sname = name
self.childs = {}
root = None
for node in graph.nodes():
scope_name = node.scopeName()
if scope_name == '':
continue
segs = scope_name.split('/')
if root is None:
root = SNode(segs[0])
curr = root
for seg in segs[1:]:
if not seg in curr.childs:
curr.childs[seg] = SNode(seg)
curr = curr.childs[seg]
leaf_nodes = []
def traverse_tree(node, scope_name):
if scope_name == '':
sn = node.sname
else:
sn = scope_name + '/' + node.sname
if not node.childs:
if node.sname[-1] == ']':
leaf_nodes.append(sn)
else:
for key in node.childs:
traverse_tree(node.childs[key], sn)
traverse_tree(root, '')
return leaf_nodes
def _build_graph(self):
"""
Build graph using our defined format from jit trace.
There are basically three steps: first, construct necessary information (data structures),
second, extract all the modules to convert to GNode, Third, extract all functions to convert
to GNode.
Returns
-------
dict
use name to index g_nodes, key: node name, value: g_node
dict
use input (its name) to index g_nodes,
key: input, value: list of g_nodes that take this input
dict
use output (its name) to index g_nodes,
key: output, value: g_node that generates this output
"""
graph = self.trace_graph.graph
# if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
_logger.debug(graph)
# build output mapping, from output debugName to its node
output_to_node = dict()
# build input mapping, from input debugName to its node
input_to_node = dict()
# build module mapping, from module name to all nodes (as list) under this module scope
module_to_nodes = dict()
# module name to its type
module_to_type = dict()
# the mapping of function (non-module in forward) to nodes, key is scope name
func_to_nodes = dict()
graph_inputs = list()
graph_outputs = list()
for _input in graph.inputs():
graph_inputs.append(_input.debugName())
for output in graph.outputs():
graph_outputs.append(output.debugName())
leaf_modules = self._extract_leaf_modules(graph)
_logger.debug(leaf_modules)
for node in graph.nodes():
# populate output_to_node and input_to_node
for output in node.outputs():
output_name = output.debugName()
output_to_node[output_name] = node
for _input in node.inputs():
input_name = _input.debugName()
input_to_node[input_name] = node
scope_name = node.scopeName() # example: scope_name, 'MyCell/Linear[linear]'
# if module_name is empty, it is not a module
if not scope_name in leaf_modules:
if scope_name == '':
continue
else:
if scope_name in func_to_nodes:
func_to_nodes[scope_name].append(node)
else:
func_to_nodes[scope_name] = [node]
else:
module_name_slices = re.findall(r'\[(.*?)\]', scope_name)
module_name = '.'.join(module_name_slices)
scope_slice = scope_name.split('/')[-1]
module_type = scope_slice.split('[')[0]
module_to_type[module_name] = module_type
if module_name in module_to_nodes:
module_to_nodes[module_name].append(node)
else:
module_to_nodes[module_name] = [node]
# construct GNode from module
for module_name, nodes in module_to_nodes.items():
inputs = set()
outputs = set()
for node in nodes:
for output in node.outputs():
outputs.add(output.debugName())
for _input in node.inputs():
inputs.add(_input.debugName())
m_inputs = list()
m_outputs = list()
for output in outputs:
# TODO: one input could be the input of multiple nodes
if not output in input_to_node and output in graph_outputs:
m_outputs.append(output)
elif not input_to_node[output] in nodes:
m_outputs.append(output)
for _input in inputs:
if not _input in output_to_node and _input in graph_inputs:
m_inputs.append(_input)
elif not output_to_node[_input] in nodes:
m_inputs.append(_input)
if module_name == '':
_logger.warning("module_name is empty string")
g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
self.g_nodes.append(g_node)
# each scope_name may have multiple funcs, we split them and create GNode for each of them
for scope_name, nodes in func_to_nodes.items():
# extract non prim:: nodes
non_prim_nodes = list()
for node in nodes:
if not node.kind().startswith('prim::'):
non_prim_nodes.append(node)
# for each non prim node, expand it has a GNode
for node in non_prim_nodes:
g_node = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
self.g_nodes.append(g_node)
# get shape infor for view (aten::view) func
if g_node.op_type == 'aten::view':
g_node.auxiliary = self._extract_shape_info(node)
# build index for g_nodes
name_to_gnode, input_to_gnode, output_to_gnode = self._build_index_for_gnodes(self.g_nodes)
return name_to_gnode, input_to_gnode, output_to_gnode
def _find_predecessors(self, module_name):
"""
Find predecessor GNode of the given GNode
Parameters
----------
module_name : str
The name of the GNode
Returns
-------
list
a list of GNodes who are the given GNode's predecessor
"""
predecessors = []
for _input in self.name_to_gnode[module_name].inputs:
if not _input in self.output_to_gnode:
_logger.debug("cannot find gnode with %s as its output", _input)
else:
g_node = self.output_to_gnode[_input]
predecessors.append(g_node.name)
return predecessors
def _find_successors(self, module_name):
"""
Find successor GNodes of the given GNode
Parameters
----------
module_name : str
The name of the GNode
Returns
-------
list
a list of GNodes who are the given GNode's successor
"""
successors = []
for output in self.name_to_gnode[module_name].outputs:
assert output in self.input_to_gnode, "No gnode with input {}".format(output)
g_nodes = self.input_to_gnode[output]
for g_node in g_nodes:
successors.append(g_node.name)
return successors
def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None): def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
""" """
...@@ -441,13 +70,13 @@ class ModelSpeedup: ...@@ -441,13 +70,13 @@ class ModelSpeedup:
Parameters Parameters
---------- ----------
module_name : str module_name : str
The name of the GNode The name of the node
mask : tensor of mask or ModuleMasks mask : tensor of mask or ModuleMasks
Mask of the weights in this GNode (i.e., module) Mask of the weights in this node (i.e., module)
in_shape : ModuleMasks in_shape : ModuleMasks
Input shape of this GNode Input shape of this node
out_shape : ModuleMasks out_shape : ModuleMasks
Output shape of this GNode Output shape of this node
""" """
input_cmask = output_cmask = None input_cmask = output_cmask = None
if module_name in self.inferred_masks: if module_name in self.inferred_masks:
...@@ -456,7 +85,7 @@ class ModelSpeedup: ...@@ -456,7 +85,7 @@ class ModelSpeedup:
module_masks = ModuleMasks(module_name) module_masks = ModuleMasks(module_name)
self.inferred_masks[module_name] = module_masks self.inferred_masks[module_name] = module_masks
m_type = self.name_to_gnode[module_name].op_type m_type = self.torch_graph.name_to_node[module_name].op_type
_logger.debug("infer mask of module %s with op_type %s", module_name, m_type) _logger.debug("infer mask of module %s with op_type %s", module_name, m_type)
if mask is not None: if mask is not None:
_logger.debug("mask is not None") _logger.debug("mask is not None")
...@@ -471,10 +100,10 @@ class ModelSpeedup: ...@@ -471,10 +100,10 @@ class ModelSpeedup:
raise RuntimeError( raise RuntimeError(
"Has not supported infering output shape from input shape for module/function: `{}`, {}" "Has not supported infering output shape from input shape for module/function: `{}`, {}"
.format(m_type, module_name)) .format(m_type, module_name))
if m_type == 'aten::view': if m_type in ['aten::view', 'aten::flatten']:
output_cmask = infer_from_inshape[m_type](module_masks, output_cmask = infer_from_inshape[m_type](module_masks,
in_shape, in_shape,
self.name_to_gnode[module_name].auxiliary) self.torch_graph.name_to_node[module_name].auxiliary)
else: else:
output_cmask = infer_from_inshape[m_type](module_masks, in_shape) output_cmask = infer_from_inshape[m_type](module_masks, in_shape)
if out_shape is not None: if out_shape is not None:
...@@ -486,11 +115,11 @@ class ModelSpeedup: ...@@ -486,11 +115,11 @@ class ModelSpeedup:
input_cmask = infer_from_outshape[m_type](module_masks, out_shape) input_cmask = infer_from_outshape[m_type](module_masks, out_shape)
if input_cmask: if input_cmask:
predecessors = self._find_predecessors(module_name) predecessors = self.torch_graph.find_predecessors(module_name)
for _module_name in predecessors: for _module_name in predecessors:
self.infer_module_mask(_module_name, out_shape=input_cmask) self.infer_module_mask(_module_name, out_shape=input_cmask)
if output_cmask: if output_cmask:
successors = self._find_successors(module_name) successors = self.torch_graph.find_successors(module_name)
for _module_name in successors: for _module_name in successors:
self.infer_module_mask(_module_name, in_shape=output_cmask) self.infer_module_mask(_module_name, in_shape=output_cmask)
...@@ -511,7 +140,7 @@ class ModelSpeedup: ...@@ -511,7 +140,7 @@ class ModelSpeedup:
is that ```func``` should be not required to be replaced. is that ```func``` should be not required to be replaced.
""" """
for module_name in self.inferred_masks: for module_name in self.inferred_masks:
g_node = self.name_to_gnode[module_name] g_node = self.torch_graph.name_to_node[module_name]
_logger.debug("replace %s, in %s type, with op_type %s", _logger.debug("replace %s, in %s type, with op_type %s",
module_name, g_node.type, g_node.op_type) module_name, g_node.type, g_node.op_type)
if g_node.type == 'module': if g_node.type == 'module':
...@@ -526,7 +155,7 @@ class ModelSpeedup: ...@@ -526,7 +155,7 @@ class ModelSpeedup:
_logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type", _logger.info("Warning: cannot replace (name: %s, op_type: %s) which is func type",
module_name, g_node.op_type) module_name, g_node.op_type)
else: else:
raise RuntimeError("Unsupported GNode type: {}".format(g_node.type)) raise RuntimeError("Unsupported node type: {}".format(g_node.type))
def speedup_model(self): def speedup_model(self):
""" """
...@@ -540,8 +169,3 @@ class ModelSpeedup: ...@@ -540,8 +169,3 @@ class ModelSpeedup:
_logger.info("replace compressed modules...") _logger.info("replace compressed modules...")
self.replace_compressed_modules() self.replace_compressed_modules()
_logger.info("speedup done") _logger.info("speedup done")
# resume the model mode to that before the model is speed up
if self.is_training:
self.bound_model.train()
else:
self.bound_model.eval()
\ No newline at end of file
...@@ -83,6 +83,9 @@ class CoarseMask: ...@@ -83,6 +83,9 @@ class CoarseMask:
cmask.mask_index[i]) cmask.mask_index[i])
return self.mask_index return self.mask_index
def __repr__(self):
return 'mask_index: {}'.format(self.mask_index)
class ModuleMasks: class ModuleMasks:
""" """
The masks of a module, including the masks for weights, inputs, output The masks of a module, including the masks for weights, inputs, output
...@@ -128,6 +131,11 @@ class ModuleMasks: ...@@ -128,6 +131,11 @@ class ModuleMasks:
""" """
self.output_mask = mask self.output_mask = mask
def __repr__(self):
return 'input_mask: {}, output_mask: {}, param_masks: {}'.format(
self.input_mask, self.output_mask, self.param_masks
)
""" """
Infer input and output shape of a module/function from its weight mask Infer input and output shape of a module/function from its weight mask
""" """
...@@ -147,8 +155,10 @@ infer_from_inshape = { ...@@ -147,8 +155,10 @@ infer_from_inshape = {
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), 'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), 'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'AvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), 'AvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::size': lambda module_masks, mask: size_inshape(module_masks, mask), 'aten::size': lambda module_masks, mask: size_inshape(module_masks, mask),
'aten::view': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape), 'aten::view': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape),
'aten::flatten': lambda module_masks, mask, shape: view_inshape(module_masks, mask, shape), # support only start_dim=1
'Linear': lambda module_masks, mask: linear_inshape(module_masks, mask), 'Linear': lambda module_masks, mask: linear_inshape(module_masks, mask),
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask) 'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask)
} }
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
# This file is copied from PyTorch 1.4, with bug fixes.
# Likely to be removed in future.
import torch
from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
from tensorboard.compat.proto.versions_pb2 import VersionDef
from torch.utils.tensorboard._pytorch_graph import GraphPy, CLASSTYPE_KIND, GETATTR_KIND, NodePyIO, NodePyOP
def parse(graph, trace, args=None, omit_useless_nodes=True):
    """This method parses an optimized PyTorch model graph and produces
    a list of nodes and node stats for eventual conversion to TensorBoard
    protobuf format.

    Args:
      graph (PyTorch module): The model graph to be parsed.
      trace (PyTorch JIT TracedModule): The model trace to be parsed.
      args (tuple): input tensor[s] for the model; kept for interface
        compatibility, not used by the parsing itself.
      omit_useless_nodes (boolean): Whether to remove nodes from the graph.
    """
    # NOTE: the originally assigned-but-unused locals (`n_inputs = len(args)`,
    # `scope = {}`) were removed; `len(args)` also crashed on the documented
    # default `args=None`.
    nodes_py = GraphPy()
    for node in graph.inputs():
        if omit_useless_nodes:
            if len(node.uses()) == 0:  # number of user of the node (= number of outputs/ fanout)
                continue

        if node.type().kind() != CLASSTYPE_KIND:
            nodes_py.append(NodePyIO(node, 'input'))

    attr_to_scope = dict()
    # A jit node prints as "<output> : <type> = <kind>(...)"; the text before
    # the first ':' serves as a stable per-node identifier.
    node_to_name = lambda d: str(d).split(":")[0].strip()
    for node in graph.nodes():
        if node.kind() == GETATTR_KIND:
            attr_name = node.s('name')
            node_name = node_to_name(node)
            parent = node.input().node()
            if parent.kind() == GETATTR_KIND:  # If the parent node is not the top-level "self" node
                parent_scope = attr_to_scope[node_to_name(parent)]
                attr_scope = parent_scope.split('/')[-1]
                attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
            else:
                attr_to_scope[node_name] = '__module.{}'.format(attr_name)
            # We don't need classtype nodes; scope will provide this information
            if node.output().type().kind() != CLASSTYPE_KIND:
                node_py = NodePyOP(node)
                node_py.scopeName = attr_to_scope[node_name]
                nodes_py.append(node_py)
        else:
            nodes_py.append(NodePyOP(node))

    for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
        node_py = NodePyIO(node, 'output')
        node_py.debugName = "output.{}".format(i + 1)
        node_py.inputs = [node.debugName()]
        nodes_py.append(node_py)

    def parse_traced_name(module_name):
        # Strip the "TracedModule[...]" wrapper that traced submodules carry.
        prefix = 'TracedModule['
        suffix = ']'
        if module_name.startswith(prefix) and module_name.endswith(suffix):
            module_name = module_name[len(prefix):-len(suffix)]
        return module_name

    alias_to_name = dict()
    base_name = parse_traced_name(trace._name)
    for name, module in trace.named_modules(prefix='__module'):
        mod_name = parse_traced_name(module._name)
        attr_name = name.split('.')[-1]
        alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)

    # Rewrite each op node's scopeName from attribute aliases to
    # human-readable "Type[attr]" path segments.
    for node in nodes_py.nodes_op:
        module_aliases = node.scopeName.split('/')[-1].split('.')
        module_name = ''
        for i, alias in enumerate(module_aliases):
            if i == 0:
                module_name = alias
                node.scopeName = base_name
            else:
                module_name += '.' + alias
                node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias)

    nodes_py.populate_namespace_from_OP_to_IO()
    return nodes_py.to_proto()
def graph(model, args, verbose=False):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print out verbose information while
        processing.
    """
    with torch.onnx.set_training(model, False):  # TODO: move outside of torch.onnx?
        try:
            trace = torch.jit.trace(model, args)
            graph = trace.graph
            torch._C._jit_pass_inline(graph)
        except RuntimeError as e:
            print(e)
            print('Error occurs, No graph saved')
            # bare `raise` re-raises the active exception with its original
            # traceback (more idiomatic than `raise e`)
            raise
    if verbose:
        print(graph)
    list_of_nodes = parse(graph, trace, args)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
# The producer version has been reverse engineered from standard
# TensorBoard logged data.
...@@ -107,12 +107,12 @@ class Mutator(BaseMutator): ...@@ -107,12 +107,12 @@ class Mutator(BaseMutator):
""" """
if not torch.__version__.startswith("1.4"): if not torch.__version__.startswith("1.4"):
logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.") logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
from ._graph_utils import graph from nni._graph_utils import build_graph
from google.protobuf import json_format from google.protobuf import json_format
# protobuf should be installed as long as tensorboard is installed # protobuf should be installed as long as tensorboard is installed
try: try:
self._connect_all = True self._connect_all = True
graph_def, _ = graph(self.model, inputs, verbose=False) graph_def, _ = build_graph(self.model, inputs, verbose=False)
result = json_format.MessageToDict(graph_def) result = json_format.MessageToDict(graph_def)
finally: finally:
self._connect_all = False self._connect_all = False
......
node {
name: "input/input"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "myLinear/Linear[l]/22"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "myLinear/Linear[l]/bias/17"
op: "prim::GetAttr"
input: "myLinear/Linear[l]/weight/14"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "myLinear/Linear[l]/weight/18"
op: "prim::GetAttr"
input: "myLinear/Linear[l]/weight/14"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "myLinear/Linear[l]/19"
op: "aten::t"
input: "myLinear/Linear[l]/weight/18"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 3
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "myLinear/Linear[l]/20"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "myLinear/Linear[l]/21"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "myLinear/Linear[l]/22"
op: "aten::addmm"
input: "myLinear/Linear[l]/bias/17"
input: "input/input"
input: "myLinear/Linear[l]/19"
input: "myLinear/Linear[l]/20"
input: "myLinear/Linear[l]/21"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
node {
name: "input/input.1"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "input/input.1"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "MyModule/Linear[weight]/bias/49"
op: "prim::GetAttr"
input: "MyModule/Linear[weight]/weight/35"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/Linear[weight]/weight/50"
op: "prim::GetAttr"
input: "MyModule/Linear[weight]/weight/35"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/Linear[weight]/51"
op: "aten::t"
input: "MyModule/Linear[weight]/weight/50"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[weight]/52"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[weight]/53"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[weight]/54"
op: "aten::addmm"
input: "MyModule/Linear[weight]/bias/49"
input: "input/input.1"
input: "MyModule/Linear[weight]/51"
input: "MyModule/Linear[weight]/52"
input: "MyModule/Linear[weight]/53"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[bias]/bias/55"
op: "prim::GetAttr"
input: "MyModule/Linear[bias]/weight/38"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/Linear[bias]/weight/56"
op: "prim::GetAttr"
input: "MyModule/Linear[bias]/weight/38"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/Linear[bias]/57"
op: "aten::t"
input: "MyModule/Linear[bias]/weight/56"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[bias]/58"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[bias]/59"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[bias]/60"
op: "aten::addmm"
input: "MyModule/Linear[bias]/bias/55"
input: "input/input.1"
input: "MyModule/Linear[bias]/57"
input: "MyModule/Linear[bias]/58"
input: "MyModule/Linear[bias]/59"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/23"
op: "prim::ListConstruct"
input: "MyModule/Linear[weight]/54"
input: "MyModule/Linear[bias]/60"
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/24"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/input"
op: "aten::cat"
input: "MyModule/23"
input: "MyModule/24"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 6
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/61"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
node {
name: "input/input.1"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "MyModule/ModuleList[module]/Linear[1]/46"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/bias/35"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[0]/weight/26"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/weight/36"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[0]/weight/26"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/37"
op: "aten::t"
input: "MyModule/ModuleList[module]/Linear[0]/weight/36"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/38"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/39"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/input"
op: "aten::addmm"
input: "MyModule/ModuleList[module]/Linear[0]/bias/35"
input: "input/input.1"
input: "MyModule/ModuleList[module]/Linear[0]/37"
input: "MyModule/ModuleList[module]/Linear[0]/38"
input: "MyModule/ModuleList[module]/Linear[0]/39"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/bias/41"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[1]/weight/30"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/weight/42"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[1]/weight/30"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/43"
op: "aten::t"
input: "MyModule/ModuleList[module]/Linear[1]/weight/42"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 3
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/44"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/45"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/46"
op: "aten::addmm"
input: "MyModule/ModuleList[module]/Linear[1]/bias/41"
input: "MyModule/ModuleList[module]/Linear[0]/input"
input: "MyModule/ModuleList[module]/Linear[1]/43"
input: "MyModule/ModuleList[module]/Linear[1]/44"
input: "MyModule/ModuleList[module]/Linear[1]/45"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
import os
import math
import uuid
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboard.compat.proto.graph_pb2 import GraphDef
from google.protobuf import text_format
import unittest
from unittest import TestCase, main
from nni._graph_utils import build_module_graph, build_graph
class BackboneModel1(nn.Module):
    """Minimal backbone: a single 1x1 convolution (shape-preserving)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1, 1)

    def forward(self, x):
        out = self.conv1(x)
        return out
class BackboneModel2(torch.nn.Module):
    """LeNet-style backbone: two conv+BN+pool stages then two FC layers."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2, 2)
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2, 2)
        flat = x.view(x.size(0), -1)  # flatten to (batch, 4*4*50)
        return self.fc2(F.relu(self.fc1(flat)))
class BigModel(torch.nn.Module):
    """Composite model chaining BackboneModel1, BackboneModel2 and a final FC."""

    def __init__(self):
        super().__init__()
        self.backbone1 = BackboneModel1()
        self.backbone2 = BackboneModel2()
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        out = self.backbone1(x)
        out = self.backbone2(out)
        return self.fc3(out)
class GraphUtilsTestCase(TestCase):
    """Tests for the graph-building helpers in nni._graph_utils."""

    def test_build_module_graph(self):
        """build_module_graph should expose leaf modules and their connectivity."""
        big_model = BigModel()
        g = build_module_graph(big_model, torch.randn(2, 1, 28, 28))
        print(g.name_to_node.keys())
        leaf_modules = set([
            'backbone1.conv1', 'backbone2.bn1', 'backbone2.bn2', 'backbone2.conv1',
            'backbone2.conv2', 'backbone2.fc1', 'backbone2.fc2', 'fc3'
        ])
        assert set(g.leaf_modules) == leaf_modules
        assert not leaf_modules - set(g.name_to_node.keys())
        assert g.find_successors('backbone2.conv1') == ['backbone2.bn1']
        assert g.find_successors('backbone2.conv2') == ['backbone2.bn2']
        assert g.find_predecessors('backbone2.bn1') == ['backbone2.conv1']
        assert g.find_predecessors('backbone2.bn2') == ['backbone2.conv2']

    def _test_graph(self, model, dummy_input, expected_file):
        # Build the GraphDef proto for `model` and compare it node-by-node
        # against the golden text-format file under tests/expect.
        # Note: uses assertEqual -- `assertEquals` is a deprecated alias that
        # was removed in Python 3.12.
        actual_proto, _ = build_graph(model, dummy_input)
        assert os.path.exists(expected_file), expected_file
        with open(expected_file, "r") as f:
            expected_str = f.read()
        expected_proto = GraphDef()
        text_format.Parse(expected_str, expected_proto)
        self.assertEqual(len(expected_proto.node), len(actual_proto.node))
        for i in range(len(expected_proto.node)):
            expected_node = expected_proto.node[i]
            actual_node = actual_proto.node[i]
            self.assertEqual(expected_node.name, actual_node.name)
            self.assertEqual(expected_node.op, actual_node.op)
            self.assertEqual(expected_node.input, actual_node.input)
            self.assertEqual(expected_node.device, actual_node.device)
            self.assertEqual(
                sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))

    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
    def test_graph_module1(self):
        """A single wrapped Linear should match the recorded expect file."""
        dummy_input = (torch.zeros(1, 3),)

        class myLinear(torch.nn.Module):
            def __init__(self):
                super(myLinear, self).__init__()
                self.l = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.l(x)

        self._test_graph(
            myLinear(),
            dummy_input,
            os.path.join(os.path.dirname(__file__), "expect", "test_graph_module1.expect")
        )

    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
    def test_graph_module2(self):
        """Two parallel Linears joined by torch.cat should match the expect file."""
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = nn.Linear(5, 3)
                self.bias = nn.Linear(5, 3)
                self.module = nn.Linear(6, 1)

            def forward(self, x):
                tensors = [self.weight(x), self.bias(x)]
                self.module(torch.cat(tensors, dim=1))
                return x

        self._test_graph(
            MyModule(),
            torch.randn(4, 5),
            os.path.join(os.path.dirname(__file__), "expect", "test_graph_module2.expect")
        )

    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
    def test_graph_module3(self):
        """ModuleList submodules (indexed scopes) should match the expect file."""
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.module = nn.ModuleList([
                    nn.Linear(5, 3),
                    nn.Linear(3, 1)
                ])

            def forward(self, x):
                x = self.module[0](x)
                x = self.module[1](x)
                return x

        self._test_graph(
            MyModule(),
            torch.randn(4, 5),
            os.path.join(os.path.dirname(__file__), "expect", "test_graph_module3.expect")
        )


if __name__ == '__main__':
    main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.vgg import vgg16
from torchvision.models.resnet import resnet18
from unittest import TestCase, main
from nni.compression.torch import L1FilterPruner
from nni.compression.speedup.torch import ModelSpeedup
class BackboneModel1(nn.Module):
    """Minimal backbone: a single 1x1 convolution (shape-preserving)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1, 1)

    def forward(self, x):
        out = self.conv1(x)
        return out
class BackboneModel2(torch.nn.Module):
    """LeNet-style backbone: two conv+BN+pool stages then two FC layers."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2, 2)
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2, 2)
        flat = x.view(x.size(0), -1)  # flatten to (batch, 4*4*50)
        return self.fc2(F.relu(self.fc1(flat)))
class BigModel(torch.nn.Module):
    """Composite model: two backbones followed by a small FC head
    (Linear -> BatchNorm1d -> ReLU -> Linear)."""

    def __init__(self):
        super().__init__()
        self.backbone1 = BackboneModel1()
        self.backbone2 = BackboneModel2()
        self.fc3 = nn.Sequential(
            nn.Linear(10, 10),
            nn.BatchNorm1d(10),
            nn.ReLU(inplace=True),
            nn.Linear(10, 2)
        )

    def forward(self, x):
        out = self.backbone1(x)
        out = self.backbone2(out)
        return self.fc3(out)
SPARSITY = 0.5
def prune_model_l1(model):
    """Prune all Conv2d layers of `model` with the L1 filter pruner and
    export the pruned checkpoint and mask files to the working directory."""
    pruner = L1FilterPruner(model, [{
        'sparsity': SPARSITY,
        'op_types': ['Conv2d']
    }])
    pruner.compress()
    pruner.export_model(model_path='./11_model.pth', mask_path='./l1_mask.pth')
class SpeedupTestCase(TestCase):
    """End-to-end checks that ModelSpeedup shrinks pruned layers by SPARSITY."""

    def test_speedup_vgg16(self):
        prune_model_l1(vgg16())
        model = vgg16()
        model.train()
        ModelSpeedup(model, torch.randn(2, 3, 32, 32), './l1_mask.pth').speedup_model()

        orig_model = vgg16()
        assert model.training
        assert model.features[2].out_channels == int(orig_model.features[2].out_channels * SPARSITY)
        assert model.classifier[0].in_features == int(orig_model.classifier[0].in_features * SPARSITY)

    #def test_speedup_resnet(self):
        #TODO support resnet
        #model = resnet18()

    def test_speedup_bigmodel(self):
        prune_model_l1(BigModel())
        model = BigModel()
        model.train()
        ModelSpeedup(model, torch.randn(2, 1, 28, 28), './l1_mask.pth').speedup_model()

        orig_model = BigModel()
        assert model.training
        assert model.backbone2.conv1.out_channels == int(orig_model.backbone2.conv1.out_channels * SPARSITY)
        assert model.backbone2.conv2.in_channels == int(orig_model.backbone2.conv2.in_channels * SPARSITY)
        assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY)
        assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)

    def tearDown(self):
        # clean up the files written by prune_model_l1
        os.remove('./11_model.pth')
        os.remove('./l1_mask.pth')


if __name__ == '__main__':
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment