Unverified Commit 051592c6 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[fx] update MetaInforProp pass to process more complex node.meta (#1344)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [fx] update MetaInforProp pass to process more complex node.meta
parent 7a8702c0
...@@ -33,6 +33,24 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: ...@@ -33,6 +33,24 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
return TensorMetadata(shape, dtype, requires_grad, stride, numel) return TensorMetadata(shape, dtype, requires_grad, stride, numel)
def _compute_node_numel(node_metadata: any) -> int:
"""
Compute numel of a node with ``tensor_meta`` attribute.
"""
node_numel = 0
if isinstance(node_metadata, TensorMetadata):
node_numel += node_metadata.numel
elif isinstance(node_metadata, dict):
value_list = [v for _, v in node_metadata.items()]
node_numel += _compute_node_numel(value_list)
else:
for element in node_metadata:
node_numel += _compute_node_numel(element)
return node_numel
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
class MetaInfoProp(torch.fx.Interpreter): class MetaInfoProp(torch.fx.Interpreter):
""" """
...@@ -78,20 +96,13 @@ class MetaInfoProp(torch.fx.Interpreter): ...@@ -78,20 +96,13 @@ class MetaInfoProp(torch.fx.Interpreter):
return obj return obj
meta = map_aggregate(result, extract_tensor_meta) meta = map_aggregate(result, extract_tensor_meta)
if found_tensor: if found_tensor:
n.meta['tensor_meta'] = meta n.meta['tensor_meta'] = meta
else: else:
n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0) n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)
# counting the total size of node outputs # counting the total size of node outputs
total_node_size = 0 total_node_size = _compute_node_numel(n.meta['tensor_meta'])
if isinstance(n.meta['tensor_meta'], TensorMetadata):
total_node_size += n.meta['tensor_meta'].numel
else:
for element in n.meta['tensor_meta']:
assert isinstance(
element, TensorMetadata
), f"``n.meta['tensor_meta']`` should be either TensorMetadata or a tuple of TensorMetadata."
total_node_size += element.numel
# counting the total size of parameters # counting the total size of parameters
total_param_size = 0 total_param_size = 0
if n.op == 'call_module': if n.op == 'call_module':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment