The repository "git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist at commit "19ad49fb3b847ce0992c68f57ad9940c2f2b2c44".
Unverified commit bbc58d88, authored by Super Daniel and committed by GitHub
Browse files

[fx] fix MetaInfoProp for incorrect calculations and add detections for inplace op. (#1466)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] merge development into main (#1)

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] add rules to linearize computation graphs for searching. (#2)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] merge development into main (#1)

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] fix test and algorithm bugs in activation checkpointing.

* [fx] polish ckpt_test.

* [fx] add rules to linearize computation graphs for searching.

* [fx] remove chen_sqrt for sake of simplicity

* [fx] remove chen_sqrt for sake of simplicity

* [fx] remove chen_sqrt for sake of simplicity

* [fx] remove chen_sqrt for sake of simplicity

* [fx] fix inconsistencies.

* [fx] fix MetaInfoProp.

* [fx] fix MetaInfoProp.

* [fx] consider MetaInfoProp for inplace operands.

* [fx] consider MetaInfoProp for inplace operands.

* [fx] consider MetaInfoProp for inplace operands.

* [fx] consider MetaInfoProp for inplace operands.

* [fx] consider MetaInfoProp for inplace operands.
parent e7383f57
......@@ -36,20 +36,20 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)
def _compute_node_numel(node_metadata: any) -> int:
def _compute_activation_size(node_metadata: any) -> int:
"""
Compute numel of a node with ``tensor_meta`` attribute.
"""
node_numel = 0
if isinstance(node_metadata, TensorMetadata):
node_numel += node_metadata.numel
node_numel += node_metadata.numel * torch.tensor([], dtype=node_metadata.dtype).element_size()
elif isinstance(node_metadata, dict):
value_list = [v for _, v in node_metadata.items()]
node_numel += _compute_node_numel(value_list)
node_numel += _compute_activation_size(value_list)
else:
for element in node_metadata:
node_numel += _compute_node_numel(element)
node_numel += _compute_activation_size(element)
return node_numel
......@@ -105,6 +105,7 @@ class MetaInfoProp(torch.fx.Interpreter):
"""
def run_node(self, n: Node) -> Any:
# TODO: We might run_node(n) with meta data, and count FLOPS for each node
result = super().run_node(n)
def extract_tensor_meta(obj):
......@@ -116,24 +117,20 @@ class MetaInfoProp(torch.fx.Interpreter):
meta = _map_aggregate(result, extract_tensor_meta)
n.meta['tensor_meta'] = meta
# get byte size for each element
size_per_elem_bytes = torch.tensor([], dtype=meta.dtype).element_size()
# compute the total size of activation tensors
total_activation_size = _compute_node_numel(n.meta['tensor_meta'])
# compute the total size of model parameters
total_activation_size = 0
total_param_size = 0
if n.op == 'call_module':
target_module = n.graph.owning_module.get_submodule(n.target)
if not getattr(target_module, 'inplace', False):
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
for param in target_module.parameters():
total_param_size += param.numel()
# compute the total memory cost of activation tensors and model parameters
total_activation_size *= size_per_elem_bytes
total_param_size *= size_per_elem_bytes
total_param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
elif n.op == 'call_function':
if 'inplace' not in n.kwargs:
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
else:
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
# TODO: node.node_size is not an original attribute
setattr(n, 'node_size', total_activation_size + total_param_size)
setattr(n, 'param_size', total_param_size)
setattr(n, 'activation_size', total_activation_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment