Commit d95cfe26 authored by oahzxl's avatar oahzxl
Browse files

basic memory

parent c35718e8
...@@ -6,6 +6,7 @@ from typing import List, Callable, Any, Tuple, Dict, Iterable ...@@ -6,6 +6,7 @@ from typing import List, Callable, Any, Tuple, Dict, Iterable
try: try:
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size
CODEGEN_AVAILABLE = True CODEGEN_AVAILABLE = True
except: except:
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
...@@ -18,6 +19,82 @@ else: ...@@ -18,6 +19,82 @@ else:
__all__ = ['python_code_with_activation_checkpoint'] __all__ = ['python_code_with_activation_checkpoint']
def _get_meta_node_size(x):
x = x.meta['tensor_meta']
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
return x
def _get_output_node_size(n):
fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
return activation_size(fwd_out)
def _get_delete_node_size(user, user_to_last_uses):
if user.op in ('placeholder', 'output'):
return 0
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete])
return delete_size
return 0
def _get_last_usr(nodes):
node_to_last_use: Dict[Node, Node] = {}
user_to_last_uses: Dict[Node, List[Node]] = {}
def register_last_uses(n: Node, user: Node):
if n not in node_to_last_use:
node_to_last_use[n] = user
user_to_last_uses.setdefault(user, []).append(n)
for node in reversed(nodes):
map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
return user_to_last_uses
def _estimate_inference_mem(gm: torch.fx.GraphModule):
act_memory = 0
act_memory_peak_log = []
act_memory_after_node_log = []
user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
for node in gm.graph.nodes:
# if node is placeholder, just add the size of the node
if node.op == 'placeholder':
act_memory += _get_meta_node_size(node)
# skip output
elif node.op == 'output':
continue
# node is an operation, calculate tmp, output node and delete node memory
else:
# forward memory
act_memory += calculate_fwd_tmp(node)
# act_memory += calculate_fwd_out(node)
act_memory += _get_output_node_size(node)
# record max act memory
act_memory_peak_log.append(act_memory)
# delete useless memory
act_memory -= calculate_fwd_tmp(node)
act_memory -= _get_delete_node_size(node, user_to_last_uses)
act_memory_after_node_log.append(act_memory)
act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
param_memory = parameter_size(gm)
return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
def _estimate_chunk_forward_mem(gm: torch.fx.GraphModule, start_node, end_node, chunk_size):
node_size = 0
param_size = 0
for node in gm.graph.nodes:
node_size += calculate_fwd_tmp(node)
node_size += calculate_fwd_out(node)
param_size = parameter_size(gm)
return (node_size + param_size) / 1024**2, param_size / 1024**2
def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
new_shape = "[" new_shape = "["
for idx, i in enumerate(shape): for idx, i in enumerate(shape):
...@@ -342,7 +419,7 @@ def emit_ckpt_func(body, ...@@ -342,7 +419,7 @@ def emit_ckpt_func(body,
body.append(usage) body.append(usage)
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes): def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
"""Emit code with nested activation checkpoint """Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use When we detect some of the node.activation_checkpoint is a List, we will use
this function to emit the activation checkpoint codes. this function to emit the activation checkpoint codes.
...@@ -364,6 +441,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v ...@@ -364,6 +441,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
within_chunk_region = False within_chunk_region = False
node_list = list(nodes) node_list = list(nodes)
_estimate_inference_mem(meta_graph)
# find the input and output var names for each offload region # find the input and output var names for each offload region
for idx, (start, end) in enumerate(chunk_regions): for idx, (start, end) in enumerate(chunk_regions):
...@@ -418,6 +496,7 @@ if CODEGEN_AVAILABLE: ...@@ -418,6 +496,7 @@ if CODEGEN_AVAILABLE:
class ChunkCodeGen(CodeGen): class ChunkCodeGen(CodeGen):
def __init__(self, meta_graph): def __init__(self, meta_graph):
super().__init__() super().__init__()
self.meta_graph = meta_graph
self.meta_node = list(meta_graph.graph.nodes) self.meta_node = list(meta_graph.graph.nodes)
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode: def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
...@@ -612,7 +691,7 @@ if CODEGEN_AVAILABLE: ...@@ -612,7 +691,7 @@ if CODEGEN_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we # if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen # will use nested type of activation checkpoint codegen
emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node) emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node, self.meta_graph)
if len(body) == 0: if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body # If the Graph has no non-placeholder nodes, no lines for the body
......
...@@ -2,6 +2,7 @@ import copy ...@@ -2,6 +2,7 @@ import copy
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import pytest import pytest
import torch.fx
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.fx import GraphModule from torch.fx import GraphModule
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
...@@ -56,18 +57,15 @@ def _run_offload_codegen(rank): ...@@ -56,18 +57,15 @@ def _run_offload_codegen(rank):
pair = torch.randn(1, 32, 32, 128).cuda() pair = torch.randn(1, 32, 32, 128).cuda()
# trace the module and replace codegen # trace the module and replace codegen
tracer = ColoTracer(trace_act_ckpt=True) graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
graph = tracer.trace(model) gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
gm_prop = torch.fx.GraphModule(model, graph) interp = MetaInfoProp(gm_prop)
interp = MetaInfoProp(gm_prop) interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
# now run it twice to get meta info in graph module, not necessary
gm = torch.fx.GraphModule(model, graph)
interp = MetaInfoProp(gm)
interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0')) interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
# annotate the chunk part
# for node in graph.nodes:
# if node.name == "linear0":
# setattr(node, "activation_offload", [0, True, False])
# if node.name == "linear1":
# setattr(node, "activation_offload", [0, True, False])
codegen = ChunkCodeGen(gm_prop) codegen = ChunkCodeGen(gm_prop)
graph.set_codegen(codegen) graph.set_codegen(codegen)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment