Commit 7ee5036b (unverified), authored by Ningxin Zheng, committed by GitHub

Bugfix issue2485 (#2524)

parent e1e1977c
@@ -8,19 +8,21 @@ import re
 from collections import defaultdict
 import torch
 from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy

 CLASSTYPE_KIND = 'ClassType'
 GETATTR_KIND = 'prim::GetAttr'

 _logger = logging.getLogger(__name__)

 def build_module_graph(model, dummy_input):
     return TorchModuleGraph(model, dummy_input)

 def build_graph(model, dummy_input, verbose=False):
     g = TorchProtoGraph(model, dummy_input, verbose)
     return g.graph_def, g.stepstats

 def parse_traced_name(module_name):
     prefix = 'TracedModule['
     suffix = ']'
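A minimal usage sketch for these two entry points (illustrative, not part of the diff; the toy model and input shapes are assumptions):

    import torch
    import torch.nn as nn
    from nni._graph_utils import build_module_graph, build_graph

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    dummy_input = torch.randn(2, 8)

    # TorchModuleGraph: one grouped node per nn.Module or functional call
    module_graph = build_module_graph(model, dummy_input)

    # TorchProtoGraph: a protobuf GraphDef plus step stats, e.g. for TensorBoard
    graph_def, step_stats = build_graph(model, dummy_input)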
@@ -28,11 +30,13 @@ def parse_traced_name(module_name):
         module_name = module_name[len(prefix):-len(suffix)]
     return module_name

 class TorchGraph:
     """
     This class is to extract pytorch model topology graph by tracing
     """
-    def __init__(self, model, dummy_input):
+    def __init__(self, model=None, dummy_input=None, traced_model=None):
         """
         Parameters
         ----------
@@ -40,25 +44,39 @@ class TorchGraph:
             The model user wants to speed up
         dummy_input : pytorch tensor
             The dummy input for ```jit.trace```, users should put it on right device before pass in
+        traced_model : torch._C.torch.jit.TopLevelTracedModule
+            An already traced model. If traced_model is not None, TorchGraph will build
+            the graph based on this traced model and won't trace the model again.
         """
         assert torch.__version__ >= '1.3.1'
-        self.bound_model = model
-        self._trace(model, dummy_input)
+        # check if the input is legal
+        if traced_model is not None:
+            assert isinstance(traced_model, torch.jit.TopLevelTracedModule)
+            self.trace = traced_model
+            # it's ok if the graph is already unpacked
+            torch._C._jit_pass_inline(self.trace.graph)
+        elif model is not None and dummy_input is not None:
+            self.bound_model = model
+            self._trace(model, dummy_input)
+        else:
+            raise Exception(
+                'Please provide model & dummy_input or the traced_model as inputs')

     def _trace(self, model, dummy_input):
         with torch.onnx.set_training(model, False):
             self.trace = torch.jit.trace(model, dummy_input)
             torch._C._jit_pass_inline(self.trace.graph)

 class TorchProtoGraph(TorchGraph):
     """
-    Generates model graph for pytorch models in protobuf, this implementation is borrowed from pytorch v1.4.0,
-    and fixed following issues:
+    Generates model graph for pytorch models in protobuf, this implementation
+    is borrowed from pytorch v1.4.0, and fixes the following issues:
     https://github.com/pytorch/pytorch/issues/33691
     https://github.com/pytorch/pytorch/issues/33670
     """

     def __init__(self, model, dummy_input, verbose=False):
         super().__init__(model, dummy_input)
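A minimal sketch of the new traced_model path added above (illustrative, not part of the diff; it assumes a torch version whose torch.jit.trace returns a TopLevelTracedModule, as the assert requires):

    import torch
    import torch.nn as nn
    from nni._graph_utils import TorchModuleGraph

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 10)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.fc(x))

    traced = torch.jit.trace(TinyNet(), torch.randn(2, 10))
    # The graph is built directly from `traced`; the model is not traced again.
    graph = TorchModuleGraph(traced_model=traced)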
@@ -70,8 +88,10 @@ class TorchProtoGraph(TorchGraph):
         list_of_nodes = self.parse(self.trace.graph, self.trace, dummy_input)
         if verbose:
             print(self.trace.graph)
-        self.stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
-        self.graph_def = GraphDef(node=list_of_nodes, versions=VersionDef(producer=22))
+        self.stepstats = RunMetadata(step_stats=StepStats(
+            dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
+        self.graph_def = GraphDef(
+            node=list_of_nodes, versions=VersionDef(producer=22))

     def parse(self, graph, trace, args=None, omit_useless_nodes=True):
         """This method parses an optimized PyTorch model graph and produces
@@ -94,16 +114,20 @@ class TorchProtoGraph(TorchGraph):
                 nodes_py.append(NodePyIO(node, 'input'))

         attr_to_scope = dict()
-        node_to_name = lambda d: str(d).split(":")[0].strip()
+
+        def node_to_name(d):
+            return str(d).split(":")[0].strip()
+
         for node in graph.nodes():
             if node.kind() == GETATTR_KIND:
                 attr_name = node.s('name')
                 node_name = node_to_name(node)
                 parent = node.input().node()
-                if parent.kind() == GETATTR_KIND:  # If the parent node is not the top-level "self" node
+                # If the parent node is not the top-level "self" node
+                if parent.kind() == GETATTR_KIND:
                     parent_scope = attr_to_scope[node_to_name(parent)]
                     attr_scope = parent_scope.split('/')[-1]
-                    attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
+                    attr_to_scope[node_name] = '{}/{}.{}'.format(
+                        parent_scope, attr_scope, attr_name)
                 else:
                     attr_to_scope[node_name] = '__module.{}'.format(attr_name)
         # We don't need classtype nodes; scope will provide this information
@@ -114,7 +138,8 @@ class TorchProtoGraph(TorchGraph):
             else:
                 nodes_py.append(NodePyOP(node))

-        for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
+        # Create sink nodes for output ops
+        for i, node in enumerate(graph.outputs()):
             node_py = NodePyIO(node, 'output')
             node_py.debugName = "output.{}".format(i + 1)
             node_py.inputs = [node.debugName()]
@@ -136,23 +161,33 @@ class TorchProtoGraph(TorchGraph):
                     node.scopeName = base_name
                 else:
                     module_name += '.' + alias
-                    node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias)
+                    node.scopeName += '/' + \
+                        (alias_to_name[module_name]
+                         if module_name in alias_to_name else alias)

         nodes_py.populate_namespace_from_OP_to_IO()
         return nodes_py.to_proto()

 class NodePyGroup(NodePy):
     """
     This class is used to represent a graph node which consists of multiple jit traced nodes. In a pytorch trace graph,
     there are multiple nodes are traced for one torch.nn.Module object, we group them together to form a single node to
     represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
     """

-    def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=None):
+    def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None):
         """
         Parameters:
         -----------
         name: str
             node name, such as `conv1`, `backbone.classifier`
+        unique_name: str
+            A globally unique name for the current node. Because some modules,
+            such as relu, may be reused several times, the scope name alone is
+            not a valid global identifier, so each node also carries a
+            unique_name that serves as its global identifier. Use the
+            unique_name to traverse the module graph.
         node_type: str
             `module` or `func`
         op_type: str
@@ -167,6 +202,7 @@ class NodePyGroup(NodePy):
         super(NodePyGroup, self).__init__(name, [])
         self.node_cpps = node_cpps
         self.name = name
+        self.unique_name = unique_name
         self.op_type = op_type
         self.type = node_type
         self.nodes = []
@@ -178,7 +214,7 @@ class NodePyGroup(NodePy):
     def add_nodes(self, node_cpps):
         for node_cpp in node_cpps:
             nodepy = NodePyOP(node_cpp)
-            nodepy.name = str(node_cpp).split(':')[0].strip().replace('%', '')
+            nodepy.name = node_cpp.scopeName() + '_' + node_cpp.kind()
             self.nodes.append(nodepy)

     def sub_node_names(self):
@@ -186,7 +222,8 @@ class NodePyGroup(NodePy):
     def __repr__(self):
         return 'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'.format(
-            self.name, self.type, self.op_type, self.sub_node_names(), self.inputs, self.outputs, self.auxiliary
+            self.name, self.type, self.op_type, self.sub_node_names(),
+            self.inputs, self.outputs, self.auxiliary
         )

@@ -194,12 +231,14 @@ class TorchModuleGraph(TorchGraph):
     """
     Generates model graph, each node is created from single or multiple jit trace nodes.
     """
-    def __init__(self, model, dummy_input):
-        super().__init__(model, dummy_input)
+
+    def __init__(self, model=None, dummy_input=None, traced_model=None):
+        super().__init__(model, dummy_input, traced_model)
         self.global_count = 0
         self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()

-    def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
+    def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
+                              module_type):
         """
         For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
         the functions directly called in module ```forward```. For such nodes, some of them are
@@ -217,6 +256,8 @@ class TorchModuleGraph(TorchGraph):
             key: input name, value: a node that uses this input
         output_to_node : dict
             key: output name, value: a node that generates this output
+        module_type : str
+            can be 'module' or 'func'

         Returns
         -------
@@ -224,11 +265,12 @@ class TorchModuleGraph(TorchGraph):
             the expanded non-prim node
         """
         # TODO: scope name could be empty
-        node_name = '.'.join([self._get_module_name(node.scopeName()), node.kind(), str(self.global_count)])
+        node_name = '.'.join([self._get_module_name(
+            node.scopeName()), node.kind(), str(self.global_count)])
+        unique_name = node_name
         _logger.debug("expand non-prim node, node name: %s", node_name)
         self.global_count += 1
         op_type = node.kind()
         node_group = [node]
         inputs = list()
         outputs = list()
@@ -239,38 +281,88 @@ class TorchModuleGraph(TorchGraph):
             for _input in curr_node.inputs():
                 input_name = _input.debugName()
                 if input_name in output_to_node and output_to_node[input_name] in nodes:
                     predecessor_node = output_to_node[input_name]
                     if predecessor_node.kind().startswith('prim::'):
                         node_group.append(predecessor_node)
                         node_queue.put(predecessor_node)
                     else:
                         inputs.append(input_name)
                 else:
                     inputs.append(input_name)
         for output in node.outputs():
             outputs.append(output.debugName())
-        nodepy = NodePyGroup(node_name, 'func', op_type, node_group, inputs=inputs, outputs=outputs)
+        nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
+                             node_group, inputs=inputs, outputs=outputs)
         return nodepy

-    def _build_module_node_group(self, module_name, op_type, node_cpps, input_to_node, output_to_node):
-        graph = self.trace.graph
-        inputs, outputs = [], []
-        for n in node_cpps:
-            for i in n.inputs():
-                name = i.debugName()
-                if not name in output_to_node and i in graph.inputs():
-                    inputs.append(name)
-                elif output_to_node[name] not in node_cpps:
-                    inputs.append(name)
-            for o in n.outputs():
-                name = o.debugName()
-                if not name in input_to_node and o in graph.outputs():
-                    outputs.append(name)
-                elif input_to_node[name] not in node_cpps:
-                    outputs.append(name)
-        return NodePyGroup(module_name, 'module', op_type, node_cpps, inputs, outputs)
+    def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
+                            input_to_node, output_to_node, module_type):
+        """
+        Merge the adjacent nodes of the module. The difference between
+        _expand_module_node and _expand_non_prim_node is that _expand_non_prim_node
+        only merges prim:: nodes into the aten:: node, whereas _expand_module_node
+        merges all adjacent nodes into the same NodePyGroup.
+
+        Parameters
+        ----------
+        node : trace graph node
+            The non-prim node to expand
+        node_name : str
+            specify the node_name for NodePyGroup
+        unique_name : str
+            unique_name for the NodePyGroup
+        op_type : str
+            specify the op_type for the NodePyGroup
+        nodes : list of trace graph node
+            All the trace graph nodes within the same scope as the non-prim node
+        input_to_node : dict
+            key: input name, value: a node that uses this input
+        output_to_node : dict
+            key: output name, value: a node that generates this output
+        module_type : str
+            can be 'module' or 'func'
+
+        Returns
+        -------
+        node
+            the expanded non-prim node
+        """
+        _logger.debug("expand module node, node name: %s", node_name)
+        self.global_count += 1
+        if not op_type:
+            op_type = node.kind()
+        node_group = [node]
+        inputs = list()
+        outputs = list()
+        node_queue = queue.Queue()
+        node_queue.put(node)
+        visited = {node}
+        while not node_queue.empty():
+            curr_node = node_queue.get()
+            for _input in curr_node.inputs():
+                input_name = _input.debugName()
+                if input_name in output_to_node and output_to_node[input_name] in nodes:
+                    predecessor_node = output_to_node[input_name]
+                    if predecessor_node not in visited:
+                        node_group.append(predecessor_node)
+                        node_queue.put(predecessor_node)
+                        visited.add(predecessor_node)
+                else:
+                    inputs.append(input_name)
+            for _output in curr_node.outputs():
+                output_name = _output.debugName()
+                if output_name in input_to_node and input_to_node[output_name] in nodes:
+                    successor_node = input_to_node[output_name]
+                    if successor_node not in visited:
+                        node_group.append(successor_node)
+                        node_queue.put(successor_node)
+                        visited.add(successor_node)
+                else:
+                    outputs.append(output_name)
+        nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
+                             node_group, inputs=inputs, outputs=outputs)
+        return nodepy

     def _extract_shape_info(self, node):
         """
@@ -318,11 +410,12 @@ class TorchModuleGraph(TorchGraph):
             parts1, parts2 = name1.split('.'), name2.split('.')
             if len(parts1) >= len(parts2):
                 return False
-            for i in range(len(parts1)):
+            for i, _ in enumerate(parts1):
                 if parts2[i] != parts1[i]:
                     return False
             return True

-        module_names = sorted([x[0] for x in self.trace.named_modules() if x[0]])
+        module_names = sorted([x[0]
+                               for x in self.trace.named_modules() if x[0]])
         leaf_nodes = []
         for i, name in enumerate(module_names):
             if i + 1 >= len(module_names) or not is_parent(name, module_names[i + 1]):
@@ -354,7 +447,7 @@ class TorchModuleGraph(TorchGraph):
         input_to_node = defaultdict(list)
         output_to_node = dict()
         for node in nodes_op:
-            name_to_node[node.name] = node
+            name_to_node[node.unique_name] = node
             for _input in node.inputs:
                 input_to_node[_input].append(node)
             for output in node.outputs:
@@ -385,9 +478,11 @@ class TorchModuleGraph(TorchGraph):
         graph = self.trace.graph
         _logger.debug(graph)
         # build output mapping, from output debugName to its node
-        output_to_node = {x.debugName(): n for n in graph.nodes() for x in n.outputs()}
+        output_to_node = {x.debugName(): n for n in graph.nodes()
+                          for x in n.outputs()}
         # build input mapping, from input debugName to its node
-        input_to_node = {x.debugName(): n for n in graph.nodes() for x in n.inputs()}
+        input_to_node = {x.debugName(): n for n in graph.nodes()
+                         for x in n.inputs()}
         # build module mapping, from module name to all nodes (as list) under this module scope
         module_to_nodes = defaultdict(list)
         # the mapping of function (non-module in forward) to nodes, key is scope name
@@ -403,7 +498,8 @@ class TorchModuleGraph(TorchGraph):
             nodes_py.append(NodePyIO(node, 'input'))

         self.leaf_modules = self._extract_leaf_modules()
-        module_to_type = {name: parse_traced_name(module._name) for name, module in self.trace.named_modules()}
+        module_to_type = {name: parse_traced_name(
+            module._name) for name, module in self.trace.named_modules()}
         # associate module name with their trace graph nodes
         for node in graph.nodes():
@@ -412,14 +508,24 @@ class TorchModuleGraph(TorchGraph):
                 module_to_nodes[module_name].append(node)
             else:
                 func_to_nodes[node.scopeName()].append(node)

         # build node group for module
         for module_name, node_cpps in module_to_nodes.items():
-            node_group = self._build_module_node_group(
-                module_name, module_to_type[module_name], node_cpps, input_to_node, output_to_node
-            )
-            _logger.debug('node_group: %s', node_group)
-            nodes_py.nodes_op.append(node_group)
+            use_count = 0
+            merged = set()
+            for node in node_cpps:
+                if node not in merged:
+                    # Modules that have the same scope name may appear at different
+                    # locations in the graph. Furthermore, node_cpps also contains
+                    # many prim:: nodes, so we still need to call _expand_module_node.
+                    unique_name = module_name
+                    if use_count > 0:
+                        unique_name = module_name + '.%d' % use_count
+                    node_group = self._expand_module_node(
+                        node, module_name, unique_name, module_to_type[module_name],
+                        node_cpps, input_to_node, output_to_node, 'module')
+                    nodes_py.nodes_op.append(node_group)
+                    use_count += 1
+                    merged.update(node_group.node_cpps)

         # each scope_name may have multiple funcs, we split them and create node for each of them
         # build node group for torch.nn.functional
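To illustrate the unique-name scheme used above (a sketch, not part of the diff; the scope name is made up): the first group built for a scope keeps the scope name, and each later reuse of the same module gets a numeric suffix.

    def make_unique_names(module_name, use_count):
        # Mirrors the logic above: a '.N' suffix only from the second use onwards.
        return [module_name if i == 0 else module_name + '.%d' % i
                for i in range(use_count)]

    print(make_unique_names('backbone.relu', 3))
    # ['backbone.relu', 'backbone.relu.1', 'backbone.relu.2']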
@@ -431,11 +537,13 @@ class TorchModuleGraph(TorchGraph):
                     non_prim_nodes.append(node)
             # for each non prim node, expand it
             for node in non_prim_nodes:
-                node_group = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
+                node_group = self._expand_non_prim_node(
+                    node, nodes, input_to_node, output_to_node, 'func')
                 nodes_py.nodes_op.append(node_group)
                 # get shape infor for view (aten::view) func
                 if node_group.op_type in ['aten::view', 'aten::flatten']:
                     node_group.auxiliary = self._extract_shape_info(node)

         for node in graph.outputs():  # Create sink nodes for output ops
             node_py = NodePyIO(node, 'output')
             nodes_py.append(node_py)
@@ -444,14 +552,14 @@ class TorchModuleGraph(TorchGraph):
         # build index
         return self._build_index(self.nodes_py.nodes_op)

-    def find_predecessors(self, module_name):
+    def find_predecessors(self, unique_name):
         """
         Find predecessor node of the given node

         Parameters
         ----------
-        module_name : str
-            The name of the node
+        unique_name : str
+            The unique name of the node

         Returns
         -------
@@ -459,22 +567,22 @@ class TorchModuleGraph(TorchGraph):
             a list of nodes who are the given node's predecessor
         """
         predecessors = []
-        for _input in self.name_to_node[module_name].inputs:
+        for _input in self.name_to_node[unique_name].inputs:
             if not _input in self.output_to_node:
                 _logger.debug("cannot find node with %s as its output", _input)
             else:
                 node_py = self.output_to_node[_input]
-                predecessors.append(node_py.name)
+                predecessors.append(node_py.unique_name)
         return predecessors

-    def find_successors(self, module_name):
+    def find_successors(self, unique_name):
         """
         Find successor nodes of the given node

         Parameters
         ----------
-        module_name : str
-            The name of the node
+        unique_name : str
+            The unique name of the node

         Returns
         -------
@@ -482,9 +590,11 @@ class TorchModuleGraph(TorchGraph):
             a list of nodes who are the given node's successor
         """
         successors = []
-        for output in self.name_to_node[module_name].outputs:
-            assert output in self.input_to_node, "No node with input {}".format(output)
+        for output in self.name_to_node[unique_name].outputs:
+            if output not in self.input_to_node:
+                # may reach the output of the whole graph
+                continue
             nodes_py = self.input_to_node[output]
             for node_py in nodes_py:
-                successors.append(node_py.name)
+                successors.append(node_py.unique_name)
         return successors
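A hedged end-to-end sketch of the traversal behaviour this change enables (the model, input shape and printed names are illustrative assumptions, not from the diff): nodes are now indexed and reported by unique_name, so a reused module such as relu shows up as separate entries like 'relu' and 'relu.1'.

    import torch
    import torch.nn as nn
    from nni._graph_utils import TorchModuleGraph

    class ReuseNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(10, 10)
            self.fc2 = nn.Linear(10, 10)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.fc2(self.relu(self.fc1(x))))

    graph = TorchModuleGraph(ReuseNet(), torch.randn(1, 10))
    for unique_name in graph.name_to_node:
        # Successors are reported by unique_name as well; the reuse of self.relu
        # that feeds the graph output has an empty successor list.
        print(unique_name, '->', graph.find_successors(unique_name))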
@@ -15,7 +15,7 @@ from google.protobuf import text_format
 import unittest
 from unittest import TestCase, main

-from nni._graph_utils import build_module_graph, build_graph
+from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph

 class BackboneModel1(nn.Module):
     def __init__(self):
@@ -153,6 +153,46 @@ class GraphUtilsTestCase(TestCase):
             torch.randn(4, 5),
             os.path.join(os.path.dirname(__file__), "expect", "test_graph_module3.expect")
         )

+    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
+    def test_module_reuse(self):
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.liner1 = nn.Linear(10, 10)
+                self.relu = nn.ReLU(inplace=True)
+                self.liner2 = nn.Linear(10, 20)
+                self.liner3 = nn.Linear(20, 10)
+
+            def forward(self, x):
+                x = self.liner1(x)
+                x = self.relu(x)
+                x = self.liner2(x)
+                x = self.relu(x)
+                x = self.liner3(x)
+                x = self.relu(x)
+                return x
+
+        data = torch.rand(10, 10)
+        net = MyModule()
+        traced = torch.jit.trace(net, data)
+        modulegraph = TorchModuleGraph(traced_model=traced)
+        # Traverse the TorchModuleGraph. Because the relu module is reused,
+        # there are three cpp nodes corresponding to the same module.
+        # While traversing the graph, each cpp node (including the ones that
+        # correspond to the same relu module) should have exactly one successor.
+        for name, nodeio in modulegraph.nodes_py.nodes_io.items():
+            if nodeio.input_or_output == 'input':
+                # Find the first node of the whole graph
+                start_nodes = modulegraph.input_to_node[name]
+                # We have only one single path top-down
+                assert len(start_nodes) == 1
+                node = start_nodes[0].unique_name
+                while modulegraph.find_successors(node):
+                    nodes = modulegraph.find_successors(node)
+                    assert len(nodes) == 1
+                    node = nodes[0]

 if __name__ == '__main__':
     main()