"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "37e194f48a56723b4bc8d9e9674236cc7f90db3c"
Unverified commit 9a68cdb2, authored by Jiahang Xu, committed by GitHub

Fix: refine shape attribute (#4214)

parent 50dc05d7
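
In brief: this commit moves trace-derived shape information ('input_shape' / 'output_shape') out of Operation.parameters and into a new Operation.attributes dict, so that parameters keeps only real operator arguments (e.g. the 'dim' of aten::cat). The graph converter, shape propagation, IR serialization, operation classes, JSON fixtures, and tests are updated to match.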
@@ -707,19 +707,21 @@ class GraphConverterWithShape(GraphConverter):
         for ir_node in ir_model.get_nodes():
             if ir_node.operation.parameters is None:
                 ir_node.operation.parameters = {}
-            ir_node.operation.parameters.setdefault('input_shape', [])
-            ir_node.operation.parameters.setdefault('output_shape', [])
+            ir_node.operation.attributes.setdefault('input_shape', [])
+            ir_node.operation.attributes.setdefault('output_shape', [])

     def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
         # First, trace the whole graph
         tm_graph = self._trace(module, dummy_input)

         for node in tm_graph.nodes():
-            parameters = _extract_info_from_trace_node(node)
+            shape_parameters, parameters = _extract_info_from_trace_node(node)
             # '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
             ir_node = match_node(ir_model, node, module_name)
             if ir_node is not None:
-                ir_node.operation.parameters.update(parameters)
+                ir_node.operation.attributes.update(shape_parameters)
+                if parameters:
+                    ir_node.operation.parameters.update(parameters)

         self.propagate_shape(ir_model)
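
The setdefault calls in the first hunk guarantee that every IR node carries both shape keys even if tracing never reached it, so downstream code can index attributes['input_shape'] without a KeyError guard. A minimal illustration of that invariant, with a plain dict standing in for operation.attributes:

    attributes = {}  # a node the tracer never visited
    attributes.setdefault('input_shape', [])
    attributes.setdefault('output_shape', [])
    # safe to index now; an empty list still reads as "no shape info"
    # to helpers such as _without_shape_info() further down
    assert attributes['input_shape'] == []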
@@ -735,7 +737,7 @@ class GraphConverterWithShape(GraphConverter):
                 cand_name = build_cand_name(cand_name, submodule.label)
                 # TODO: Feed the exact input tensor if user provides input,
                 # in case the path changes according to input data.
-                lc_inputs = [torch.randn(shape) for shape in lc_node.operation.parameters['input_shape']]
+                lc_inputs = [torch.randn(shape) for shape in lc_node.operation.attributes['input_shape']]
                 self._trace_module(cand, cand_name, ir_model, lc_inputs)

     def propagate_shape(self, ir_model: 'Model'):
@@ -753,8 +755,8 @@ class GraphConverterWithShape(GraphConverter):
                 cand_node = ir_model.get_node_by_name(cand_name)
                 if _without_shape_info(cand_node):
                     propagate_shape_for_graph(ir_model.graphs[cand_name])
-                graph_node.operation.parameters['input_shape'] = cand_node.operation.parameters['input_shape']
-                graph_node.operation.parameters['output_shape'] = cand_node.operation.parameters['output_shape']
+                graph_node.operation.attributes['input_shape'] = cand_node.operation.attributes['input_shape']
+                graph_node.operation.attributes['output_shape'] = cand_node.operation.attributes['output_shape']
             else:
                 input_shape = [[]] * len(graph.input_node.operation.io_names or [])
                 output_shape = [[]] * len(graph.output_node.operation.io_names or [])
@@ -763,17 +765,17 @@ class GraphConverterWithShape(GraphConverter):
                     if _without_shape_info(node):
                         if node.name in ir_model.graphs:
                             propagate_shape_for_graph(ir_model.graphs[node.name])
-                    if node.operation.parameters['input_shape']:
-                        input_shape[edge.head_slot or 0] = node.operation.parameters['input_shape'][edge.tail_slot or 0]
-                graph_node.operation.parameters['input_shape'] = input_shape
+                    if node.operation.attributes['input_shape']:
+                        input_shape[edge.head_slot or 0] = node.operation.attributes['input_shape'][edge.tail_slot or 0]
+                graph_node.operation.attributes['input_shape'] = input_shape

                 for edge in graph.output_node.incoming_edges:
                     node = edge.head
                     if _without_shape_info(node):
                         if node.name in ir_model.graphs:
                             propagate_shape_for_graph(ir_model.graphs[node.name])
-                    if node.operation.parameters['output_shape']:
-                        output_shape[edge.tail_slot or 0] = node.operation.parameters['output_shape'][edge.head_slot or 0]
-                graph_node.operation.parameters['output_shape'] = output_shape
+                    if node.operation.attributes['output_shape']:
+                        output_shape[edge.tail_slot or 0] = node.operation.attributes['output_shape'][edge.head_slot or 0]
+                graph_node.operation.attributes['output_shape'] = output_shape

             propagate_shape_for_graph(graph_node.graph)
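
The slot arithmetic in this hunk copies shapes edge by edge: the graph-level shape at a given slot is taken from the matching slot of the node on the other end of the edge. A toy illustration of the indexing for the input case, with plain lists standing in for the IR objects (the edge slots are invented for the example):

    # an inner node that takes two inputs, with shapes known from tracing
    inner_node_input_shape = [[1, 3, 224, 224], [1, 10]]
    head_slot, tail_slot = 0, 1   # hypothetical edge: graph input slot 0 -> inner input slot 1
    input_shape = [[]] * 2        # two graph inputs, shapes initially unknown
    input_shape[head_slot or 0] = inner_node_input_shape[tail_slot or 0]
    assert input_shape == [[1, 10], []]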
...
@@ -56,15 +56,16 @@ def _extract_info_from_trace_node(trace_node):
         if shape:
             output_shape.append(shape)

-    parameters = {
+    shape_parameters = {
         'input_shape': input_shape,
         'output_shape': output_shape,
     }

     if trace_node.kind() == 'aten::cat':
-        parameters['dim'] = inputs[1].toIValue()
-
-    return parameters
+        parameters = {'dim': inputs[1].toIValue()}
+        return shape_parameters, parameters
+    else:
+        return shape_parameters, None


 def is_layerchoice_node(ir_node: Node):
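
For context, a hedged sketch of the kind of TorchScript node this function receives; the comment spells out the new two-value contract (return values are illustrative, since real shapes come from the converter's shape-propagated trace):

    import torch

    def f(x, y):
        return torch.cat([x, y], dim=1)

    traced = torch.jit.trace(f, (torch.randn(1, 2), torch.randn(1, 2)))
    cat_node = next(n for n in traced.graph.nodes() if n.kind() == 'aten::cat')
    # With this commit, _extract_info_from_trace_node(cat_node) would return
    # ({'input_shape': [...], 'output_shape': [...]}, {'dim': 1}), while a
    # shape-only node gets (shape_parameters, None).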
@@ -100,7 +101,7 @@ def match_node(ir_model: Model, torch_node, prefix=''):
     graph = ir_model.graphs.get(full_name)
     if graph is not None:
         for node in graph.get_nodes_by_type(torch_node.kind()):
-            if not node.operation.parameters['input_shape']:
+            if not node.operation.attributes['input_shape']:
                 return node
         return None
     else:
@@ -108,4 +109,4 @@ def match_node(ir_model: Model, torch_node, prefix=''):

 def _without_shape_info(node: Node):
-    return not node.operation.parameters['input_shape'] and not node.operation.parameters['output_shape']
+    return not node.operation.attributes['input_shape'] and not node.operation.attributes['output_shape']
@@ -603,16 +603,18 @@ class Node:
     @staticmethod
     def _load(graph: Graph, name: str, ir: Any) -> 'Node':
         if ir['operation']['type'] == '_cell':
-            op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}))
+            op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}), attributes=ir['operation'].get('attributes', {}))
         else:
-            op = Operation.new(ir['operation']['type'], ir['operation'].get('parameters', {}))
+            op = Operation.new(ir['operation']['type'],
+                               ir['operation'].get('parameters', {}),
+                               attributes=ir['operation'].get('attributes', {}))
         node = Node(graph, uid(), name, op)
         if 'label' in ir:
             node.update_label(ir['label'])
         return node

     def _dump(self) -> Any:
-        ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters}}
+        ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters, 'attributes': self.operation.attributes}}
         if isinstance(self.operation, Cell):
             ret['operation']['cell_name'] = self.operation.cell_name
         if self.label is not None:
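
A hedged round-trip sketch of what _dump() now writes and _load() reads back; the module path is assumed from the NNI v2 source layout, and the full Node round trip (which needs a Graph) is elided in favor of the operation-level part:

    from nni.retiarii.operation import Operation

    op = Operation.new('Flatten', parameters={},
                       attributes={'input_shape': [[1, 64, 7, 7]], 'output_shape': [[1, 3136]]})
    dumped = {'type': op.type, 'parameters': op.parameters, 'attributes': op.attributes}
    restored = Operation.new(dumped['type'], dumped['parameters'],
                             attributes=dumped['attributes'])
    assert restored.attributes == op.attributes  # shape metadata survives the round trip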
...
@@ -34,10 +34,11 @@ class Operation:
         Arbitrary key-value parameters (e.g. kernel_size).
     """

-    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False):
+    def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False, attributes: Dict[str, Any] = {}):
         assert _internal, '`Operation()` is private, use `Operation.new()` instead'
         self.type: str = type_name
         self.parameters: Dict[str, Any] = parameters
+        self.attributes: Dict[str, Any] = attributes

     def to_init_code(self, field: str) -> str:
         raise NotImplementedError()
@@ -52,9 +53,10 @@ class Operation:
         return True

     @staticmethod
-    def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None) -> 'Operation':
-        if parameters is None:
-            parameters = {}
+    def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None,
+            attributes: Dict[str, Any] = None) -> 'Operation':
+        parameters = parameters or {}
+        attributes = attributes or {}
         if type_name == '_cell':
             # NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
             return Cell(cell_name, parameters)
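
The `or {}` normalization means omitting either argument (or passing None) yields a fresh empty dict. A minimal sketch, assuming the default PyTorch framework setting and the same module path as above:

    from nni.retiarii.operation import Operation

    op = Operation.new('Flatten')    # parameters and attributes both default to None
    assert op.parameters == {} and op.attributes == {}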
@@ -67,7 +69,7 @@ class Operation:
             cls = TensorFlowOperation._find_subclass(type_name)
         else:
             raise ValueError(f'Unsupported framework: {debug_configs.framework}')
-        return cls(type_name, parameters, _internal=True)
+        return cls(type_name, parameters, _internal=True, attributes=attributes)

     @classmethod
     def _find_subclass(cls, subclass_name):
@@ -205,12 +207,11 @@ class Cell(PyTorchOperation):
        No real usage. Exists for compatibility with base class.
     """

-    def __init__(self, cell_name: str, parameters: Dict[str, Any] = None):
+    def __init__(self, cell_name: str, parameters: Dict[str, Any] = None, attributes: Dict[str, Any] = None):
         self.type = '_cell'
         self.cell_name = cell_name
-        if parameters is None:
-            parameters = {}
-        self.parameters = parameters
+        self.parameters = parameters or {}
+        self.attributes = attributes or {}

     def _to_class_name(self):
         # TODO: ugly, think about how to refactor this part
...
@@ -5,7 +5,7 @@ from ..operation import TensorFlowOperation

 class Conv2D(TensorFlowOperation):
-    def __init__(self, type_name, parameters, _internal):
+    def __init__(self, type_name, parameters, _internal, attributes=None):
         if 'padding' not in parameters:
             parameters['padding'] = 'same'
         super().__init__(type_name, parameters, _internal)
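
Since Operation.new() now calls cls(type_name, parameters, _internal=True, attributes=attributes), every operation subclass needs the extra keyword, as Conv2D does above. A hedged sketch of a custom subclass (DepthwiseConv2D is hypothetical) that also forwards the attributes to the base class:

    from nni.retiarii.operation import TensorFlowOperation

    class DepthwiseConv2D(TensorFlowOperation):
        def __init__(self, type_name, parameters, _internal, attributes=None):
            if 'padding' not in parameters:
                parameters['padding'] = 'same'
            super().__init__(type_name, parameters, _internal, attributes=attributes or {})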
@@ -4,11 +4,11 @@
     "outputs": ["metric"],
     "nodes": {
-        "stem": {"operation": {"type": "_cell", "parameters": {}, "cell_name": "stem"}},
-        "flatten": {"operation": {"type": "Flatten", "parameters": {}}},
-        "fc1": {"operation": {"type": "Dense", "parameters": {"units": 1024, "activation": "relu"}}},
-        "fc2": {"operation": {"type": "Dense", "parameters": {"units": 10}}},
-        "softmax": {"operation": {"type": "Softmax", "parameters": {}}}
+        "stem": {"operation": {"type": "_cell", "parameters": {}, "attributes": {}, "cell_name": "stem"}},
+        "flatten": {"operation": {"type": "Flatten", "parameters": {}, "attributes": {}}},
+        "fc1": {"operation": {"type": "Dense", "parameters": {"units": 1024, "activation": "relu"}, "attributes": {}}},
+        "fc2": {"operation": {"type": "Dense", "parameters": {"units": 10}, "attributes": {}}},
+        "softmax": {"operation": {"type": "Softmax", "parameters": {}, "attributes": {}}}
     },
     "edges": [
@@ -23,10 +23,10 @@
     "stem": {
         "nodes": {
-            "conv1": {"operation": {"type": "Conv2D", "parameters": {"filters": 32, "kernel_size": 5, "activation": "relu"}}},
-            "pool1": {"operation": {"type": "MaxPool2D", "parameters": {"pool_size": 2}}},
-            "conv2": {"operation": {"type": "Conv2D", "parameters": {"filters": 64, "kernel_size": 5, "activation": "relu"}}},
-            "pool2": {"operation": {"type": "MaxPool2D", "parameters": {"pool_size": 2}}}
+            "conv1": {"operation": {"type": "Conv2D", "parameters": {"filters": 32, "kernel_size": 5, "activation": "relu"}, "attributes": {}}},
+            "pool1": {"operation": {"type": "MaxPool2D", "parameters": {"pool_size": 2}, "attributes": {}}},
+            "conv2": {"operation": {"type": "Conv2D", "parameters": {"filters": 64, "kernel_size": 5, "activation": "relu"}, "attributes": {}}},
+            "pool2": {"operation": {"type": "MaxPool2D", "parameters": {"pool_size": 2}, "attributes": {}}}
         },
         "edges": [
...
@@ -24,12 +24,12 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
         conv_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')[0]
         relu_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.activation.ReLU')[0]
         pool_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.pooling.MaxPool2d')[0]
-        self.assertEqual(conv_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
-        self.assertEqual(conv_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(relu_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(relu_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(pool_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(pool_node.operation.parameters.get('output_shape'), [[1, 1, 111, 111]])
+        self.assertEqual(conv_node.operation.attributes.get('input_shape'), [[1, 3, 224, 224]])
+        self.assertEqual(conv_node.operation.attributes.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(relu_node.operation.attributes.get('input_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(relu_node.operation.attributes.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(pool_node.operation.attributes.get('input_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(pool_node.operation.attributes.get('output_shape'), [[1, 1, 111, 111]])

     def test_nested_module(self):
         class ConvRelu(nn.Module):
@@ -54,8 +54,8 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
         # check if shape propagation works
         cell_node = model_ir.get_nodes_by_type('_cell')[0]
-        self.assertEqual(cell_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
-        self.assertEqual(cell_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(cell_node.operation.attributes.get('input_shape'), [[1, 3, 224, 224]])
+        self.assertEqual(cell_node.operation.attributes.get('output_shape'), [[1, 1, 222, 222]])

     def test_layerchoice(self):
         class ConvNet(nn.Module):
@@ -75,5 +75,5 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
         # check shape info of each candidates
         conv_nodes = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')
-        self.assertEqual(conv_nodes[0].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(conv_nodes[1].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(conv_nodes[0].operation.attributes.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(conv_nodes[1].operation.attributes.get('output_shape'), [[1, 1, 222, 222]])