Unverified commit 9a68cdb2, authored by Jiahang Xu, committed by GitHub

Fix: refine shape attribute (#4214)

parent 50dc05d7
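
In short, this commit moves traced shape information ('input_shape' / 'output_shape') out of operation.parameters, which holds real operator arguments such as kernel_size, and into a dedicated operation.attributes dict. A minimal before/after sketch of the access pattern (illustrative node, not taken from the diff):

    # before: shape info was mixed into the operator's parameters
    shape = ir_node.operation.parameters['input_shape']

    # after: shape info lives in a separate attributes dict,
    # leaving parameters for constructor arguments only
    shape = ir_node.operation.attributes['input_shape']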
@@ -707,18 +707,20 @@ class GraphConverterWithShape(GraphConverter):
         for ir_node in ir_model.get_nodes():
             if ir_node.operation.parameters is None:
                 ir_node.operation.parameters = {}
-            ir_node.operation.parameters.setdefault('input_shape', [])
-            ir_node.operation.parameters.setdefault('output_shape', [])
+            ir_node.operation.attributes.setdefault('input_shape', [])
+            ir_node.operation.attributes.setdefault('output_shape', [])

     def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
         # First, trace the whole graph
         tm_graph = self._trace(module, dummy_input)

         for node in tm_graph.nodes():
-            parameters = _extract_info_from_trace_node(node)
+            shape_parameters, parameters = _extract_info_from_trace_node(node)
             # '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
             ir_node = match_node(ir_model, node, module_name)
             if ir_node is not None:
-                ir_node.operation.parameters.update(parameters)
+                ir_node.operation.attributes.update(shape_parameters)
+                if parameters:
+                    ir_node.operation.parameters.update(parameters)

         self.propagate_shape(ir_model)
@@ -735,7 +737,7 @@ class GraphConverterWithShape(GraphConverter):
             cand_name = build_cand_name(cand_name, submodule.label)
             # TODO: Feed the exact input tensor if user provides input,
             # in case the path changes according to input data.
-            lc_inputs = [torch.randn(shape) for shape in lc_node.operation.parameters['input_shape']]
+            lc_inputs = [torch.randn(shape) for shape in lc_node.operation.attributes['input_shape']]
             self._trace_module(cand, cand_name, ir_model, lc_inputs)

     def propagate_shape(self, ir_model: 'Model'):
@@ -753,8 +755,8 @@ class GraphConverterWithShape(GraphConverter):
                 cand_node = ir_model.get_node_by_name(cand_name)
                 if _without_shape_info(cand_node):
                     propagate_shape_for_graph(ir_model.graphs[cand_name])
-                graph_node.operation.parameters['input_shape'] = cand_node.operation.parameters['input_shape']
-                graph_node.operation.parameters['output_shape'] = cand_node.operation.parameters['output_shape']
+                graph_node.operation.attributes['input_shape'] = cand_node.operation.attributes['input_shape']
+                graph_node.operation.attributes['output_shape'] = cand_node.operation.attributes['output_shape']
             else:
                 input_shape = [[]] * len(graph.input_node.operation.io_names or [])
                 output_shape = [[]] * len(graph.output_node.operation.io_names or [])
@@ -763,17 +765,17 @@ class GraphConverterWithShape(GraphConverter):
                     if _without_shape_info(node):
                         if node.name in ir_model.graphs:
                             propagate_shape_for_graph(ir_model.graphs[node.name])
-                    if node.operation.parameters['input_shape']:
-                        input_shape[edge.head_slot or 0] = node.operation.parameters['input_shape'][edge.tail_slot or 0]
-                graph_node.operation.parameters['input_shape'] = input_shape
+                    if node.operation.attributes['input_shape']:
+                        input_shape[edge.head_slot or 0] = node.operation.attributes['input_shape'][edge.tail_slot or 0]
+                graph_node.operation.attributes['input_shape'] = input_shape

                 for edge in graph.output_node.incoming_edges:
                     node = edge.head
                     if _without_shape_info(node):
                         if node.name in ir_model.graphs:
                             propagate_shape_for_graph(ir_model.graphs[node.name])
-                    if node.operation.parameters['output_shape']:
-                        output_shape[edge.tail_slot or 0] = node.operation.parameters['output_shape'][edge.head_slot or 0]
-                graph_node.operation.parameters['output_shape'] = output_shape
+                    if node.operation.attributes['output_shape']:
+                        output_shape[edge.tail_slot or 0] = node.operation.attributes['output_shape'][edge.head_slot or 0]
+                graph_node.operation.attributes['output_shape'] = output_shape

             propagate_shape_for_graph(graph_node.graph)
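
To make the per-slot bookkeeping above concrete, here is a hedged illustration (shapes invented) of how the loop fills a graph node's input_shape, one entry per slot of the inner graph's input node:

    # two io_names on the input node -> two slots, initialised empty
    input_shape = [[]] * 2
    # each edge copies the consumer's traced input shape into the slot
    # the edge leaves from (head_slot), reading from the consumer's own
    # slot (tail_slot); the values below are invented
    input_shape[0] = [1, 3, 224, 224]   # e.g. edge with head_slot == 0
    input_shape[1] = [1, 10]            # e.g. edge with head_slot == 1
    graph_node.operation.attributes['input_shape'] = input_shape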
@@ -56,15 +56,16 @@ def _extract_info_from_trace_node(trace_node):
         if shape:
             output_shape.append(shape)

-    parameters = {
+    shape_parameters = {
         'input_shape': input_shape,
         'output_shape': output_shape,
     }

     if trace_node.kind() == 'aten::cat':
-        parameters['dim'] = inputs[1].toIValue()
-    return parameters
+        parameters = {'dim': inputs[1].toIValue()}
+        return shape_parameters, parameters
+    else:
+        return shape_parameters, None


 def is_layerchoice_node(ir_node: Node):
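
The refactored helper now returns a pair: shape info that exists for every traced node, plus op-specific parameters that only some kinds produce (currently just 'dim' for aten::cat). A minimal caller sketch, mirroring the converter change earlier in this diff:

    shape_parameters, parameters = _extract_info_from_trace_node(node)
    ir_node.operation.attributes.update(shape_parameters)   # 'input_shape' / 'output_shape'
    if parameters:                                          # e.g. {'dim': ...} for aten::cat
        ir_node.operation.parameters.update(parameters)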
@@ -100,7 +101,7 @@ def match_node(ir_model: Model, torch_node, prefix=''):
     graph = ir_model.graphs.get(full_name)
     if graph is not None:
         for node in graph.get_nodes_by_type(torch_node.kind()):
-            if not node.operation.parameters['input_shape']:
+            if not node.operation.attributes['input_shape']:
                 return node
         return None
     else:
@@ -108,4 +109,4 @@ def match_node(ir_model: Model, torch_node, prefix=''):

 def _without_shape_info(node: Node):
-    return not node.operation.parameters['input_shape'] and not node.operation.parameters['output_shape']
+    return not node.operation.attributes['input_shape'] and not node.operation.attributes['output_shape']
@@ -603,16 +603,18 @@ class Node:
     @staticmethod
     def _load(graph: Graph, name: str, ir: Any) -> 'Node':
         if ir['operation']['type'] == '_cell':
-            op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}))
+            op = Cell(ir['operation']['cell_name'], ir['operation'].get('parameters', {}), attributes=ir['operation'].get('attributes', {}))
         else:
-            op = Operation.new(ir['operation']['type'], ir['operation'].get('parameters', {}))
+            op = Operation.new(ir['operation']['type'],
+                               ir['operation'].get('parameters', {}),
+                               attributes=ir['operation'].get('attributes', {}))
         node = Node(graph, uid(), name, op)
         if 'label' in ir:
             node.update_label(ir['label'])
         return node

     def _dump(self) -> Any:
-        ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters}}
+        ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters, 'attributes': self.operation.attributes}}
         if isinstance(self.operation, Cell):
             ret['operation']['cell_name'] = self.operation.cell_name
         if self.label is not None:
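
With the extra key in _dump and the matching .get() in _load, a node's attributes now round-trip through serialization alongside its parameters. A hedged sketch of what node._dump() might return after shape tracing (values illustrative; the fixture updates below show the empty-attributes case):

    {'operation': {'type': 'Conv2D',
                   'parameters': {'filters': 32, 'kernel_size': 5},
                   'attributes': {'input_shape': [[1, 3, 224, 224]]}}}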
@@ -34,10 +34,11 @@ class Operation:
         Arbitrary key-value parameters (e.g. kernel_size).
     """

-    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False):
+    def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False, attributes: Dict[str, Any] = {}):
         assert _internal, '`Operation()` is private, use `Operation.new()` instead'
         self.type: str = type_name
         self.parameters: Dict[str, Any] = parameters
+        self.attributes: Dict[str, Any] = attributes

     def to_init_code(self, field: str) -> str:
         raise NotImplementedError()
@@ -52,9 +53,10 @@ class Operation:
         return True

     @staticmethod
-    def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None) -> 'Operation':
-        if parameters is None:
-            parameters = {}
+    def new(type_name: str, parameters: Dict[str, Any] = None, cell_name: str = None,
+            attributes: Dict[str, Any] = None) -> 'Operation':
+        parameters = parameters or {}
+        attributes = attributes or {}
         if type_name == '_cell':
             # NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
             return Cell(cell_name, parameters)
@@ -67,7 +69,7 @@ class Operation:
             cls = TensorFlowOperation._find_subclass(type_name)
         else:
             raise ValueError(f'Unsupported framework: {debug_configs.framework}')
-        return cls(type_name, parameters, _internal=True)
+        return cls(type_name, parameters, _internal=True, attributes=attributes)

     @classmethod
     def _find_subclass(cls, subclass_name):
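
A hedged usage sketch of the extended factory (operator and values illustrative):

    # parameters keep real constructor arguments; attributes carry
    # traced metadata such as shapes
    op = Operation.new('Conv2d',
                       parameters={'kernel_size': 3},
                       attributes={'input_shape': [[1, 3, 32, 32]]})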
@@ -205,12 +207,11 @@ class Cell(PyTorchOperation):
         No real usage. Exists for compatibility with base class.
     """

-    def __init__(self, cell_name: str, parameters: Dict[str, Any] = None):
+    def __init__(self, cell_name: str, parameters: Dict[str, Any] = None, attributes: Dict[str, Any] = None):
         self.type = '_cell'
         self.cell_name = cell_name
-        if parameters is None:
-            parameters = {}
-        self.parameters = parameters
+        self.parameters = parameters or {}
+        self.attributes = attributes or {}

     def _to_class_name(self):
         # TODO: ugly, think about how to refactor this part
@@ -5,7 +5,7 @@ from ..operation import TensorFlowOperation

 class Conv2D(TensorFlowOperation):
-    def __init__(self, type_name, parameters, _internal):
+    def __init__(self, type_name, parameters, _internal, attributes=None):
         if 'padding' not in parameters:
             parameters['padding'] = 'same'
         super().__init__(type_name, parameters, _internal)
@@ -4,11 +4,11 @@
     "outputs": ["metric"],
     "nodes": {
-        "stem": {"operation": {"type": "_cell", "parameters": {}, "cell_name": "stem"}},
-        "flatten": {"operation": {"type": "Flatten", "parameters": {}}},
-        "fc1": {"operation": {"type": "Dense", "parameters": {"units": 1024, "activation": "relu"}}},
-        "fc2": {"operation": {"type": "Dense", "parameters": {"units": 10}}},
-        "softmax": {"operation": {"type": "Softmax", "parameters": {}}}
+        "stem": {"operation": {"type": "_cell", "parameters": {}, "attributes": {}, "cell_name": "stem"}},
+        "flatten": {"operation": {"type": "Flatten", "parameters": {}, "attributes": {}}},
+        "fc1": {"operation": {"type": "Dense", "parameters": {"units": 1024, "activation": "relu"}, "attributes": {}}},
+        "fc2": {"operation": {"type": "Dense", "parameters": {"units": 10}, "attributes": {}}},
+        "softmax": {"operation": {"type": "Softmax", "parameters": {}, "attributes": {}}}
     },
     "edges": [
@@ -23,10 +23,10 @@
     "stem": {
         "nodes": {
-            "conv1": {"operation": {"type": "Conv2D", "parameters": {"filters": 32, "kernel_size": 5, "activation": "relu"}}},
-            "pool1": {"operation": {"type": "MaxPool2D", "parameters": {"pool_size": 2}}},
-            "conv2": {"operation": {"type": "Conv2D", "parameters": {"filters": 64, "kernel_size": 5, "activation": "relu"}}},
-            "pool2": {"operation": {"type": "MaxPool2D", "parameters": {"pool_size": 2}}}
+            "conv1": {"operation": {"type": "Conv2D", "parameters": {"filters": 32, "kernel_size": 5, "activation": "relu"}, "attributes": {}}},
+            "pool1": {"operation": {"type": "MaxPool2D", "parameters": {"pool_size": 2}, "attributes": {}}},
+            "conv2": {"operation": {"type": "Conv2D", "parameters": {"filters": 64, "kernel_size": 5, "activation": "relu"}, "attributes": {}}},
+            "pool2": {"operation": {"type": "MaxPool2D", "parameters": {"pool_size": 2}, "attributes": {}}}
         },
         "edges": [
@@ -24,12 +24,12 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
         conv_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')[0]
         relu_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.activation.ReLU')[0]
         pool_node = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.pooling.MaxPool2d')[0]
-        self.assertEqual(conv_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
-        self.assertEqual(conv_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(relu_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(relu_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(pool_node.operation.parameters.get('input_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(pool_node.operation.parameters.get('output_shape'), [[1, 1, 111, 111]])
+        self.assertEqual(conv_node.operation.attributes.get('input_shape'), [[1, 3, 224, 224]])
+        self.assertEqual(conv_node.operation.attributes.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(relu_node.operation.attributes.get('input_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(relu_node.operation.attributes.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(pool_node.operation.attributes.get('input_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(pool_node.operation.attributes.get('output_shape'), [[1, 1, 111, 111]])

     def test_nested_module(self):
         class ConvRelu(nn.Module):
@@ -54,8 +54,8 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):

         # check if shape propagation works
         cell_node = model_ir.get_nodes_by_type('_cell')[0]
-        self.assertEqual(cell_node.operation.parameters.get('input_shape'), [[1, 3, 224, 224]])
-        self.assertEqual(cell_node.operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(cell_node.operation.attributes.get('input_shape'), [[1, 3, 224, 224]])
+        self.assertEqual(cell_node.operation.attributes.get('output_shape'), [[1, 1, 222, 222]])

     def test_layerchoice(self):
         class ConvNet(nn.Module):
@@ -75,5 +75,5 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):

         # check shape info of each candidates
         conv_nodes = model_ir.get_nodes_by_type('__torch__.torch.nn.modules.conv.Conv2d')
-        self.assertEqual(conv_nodes[0].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
-        self.assertEqual(conv_nodes[1].operation.parameters.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(conv_nodes[0].operation.attributes.get('output_shape'), [[1, 1, 222, 222]])
+        self.assertEqual(conv_nodes[1].operation.attributes.get('output_shape'), [[1, 1, 222, 222]])