OpenDAS / ColossalAI · Commits

Commit d95cfe26
Authored Nov 07, 2022 by oahzxl
Parent: c35718e8

    basic memory

Showing 2 changed files with 90 additions and 13 deletions:

    chunk_codegen.py       +81  -2
    chunk_codegen_run.py    +9  -11
chunk_codegen.py
```diff
@@ -6,6 +6,7 @@ from typing import List, Callable, Any, Tuple, Dict, Iterable
 try:
     from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
     from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
+    from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size
     CODEGEN_AVAILABLE = True
 except:
     from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
```
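The try/except here is plain feature detection: when the newer torch.fx internals (notably `CodeGen`) import cleanly, `CODEGEN_AVAILABLE` gates the `ChunkCodeGen` definition further down the file. A minimal sketch of the same pattern (class name hypothetical):

```python
# Feature-detection import guard, mirroring the one above: prefer the newer
# torch.fx API and record availability for a conditional class definition.
try:
    from torch.fx.graph import CodeGen    # only present in newer torch.fx
    CODEGEN_AVAILABLE = True
except ImportError:
    CodeGen = object                      # stand-in base so the module still imports
    CODEGEN_AVAILABLE = False

if CODEGEN_AVAILABLE:
    class MyCodeGen(CodeGen):             # hypothetical subclass, cf. ChunkCodeGen below
        pass
```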
```diff
@@ -18,6 +19,82 @@ else:
 __all__ = ['python_code_with_activation_checkpoint']
 
+def _get_meta_node_size(x):
+    x = x.meta['tensor_meta']
+    x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
+    return x
+
+
+def _get_output_node_size(n):
+    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+    return activation_size(fwd_out)
+
+
+def _get_delete_node_size(user, user_to_last_uses):
+    if user.op in ('placeholder', 'output'):
+        return 0
+    nodes_to_delete = user_to_last_uses.get(user, [])
+    if len(nodes_to_delete):
+        delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete])
+        return delete_size
+    return 0
+
+
+def _get_last_usr(nodes):
+    node_to_last_use: Dict[Node, Node] = {}
+    user_to_last_uses: Dict[Node, List[Node]] = {}
+
+    def register_last_uses(n: Node, user: Node):
+        if n not in node_to_last_use:
+            node_to_last_use[n] = user
+            user_to_last_uses.setdefault(user, []).append(n)
+
+    for node in reversed(nodes):
+        map_arg(node.args, lambda n: register_last_uses(n, node))
+        map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+    return user_to_last_uses
+
+
+def _estimate_inference_mem(gm: torch.fx.GraphModule):
+    act_memory = 0
+    act_memory_peak_log = []
+    act_memory_after_node_log = []
+    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
+    for node in gm.graph.nodes:
+        # if node is placeholder, just add the size of the node
+        if node.op == 'placeholder':
+            act_memory += _get_meta_node_size(node)
+        # skip output
+        elif node.op == 'output':
+            continue
+        # node is an operation, calculate tmp, output node and delete node memory
+        else:
+            # forward memory
+            act_memory += calculate_fwd_tmp(node)
+            # act_memory += calculate_fwd_out(node)
+            act_memory += _get_output_node_size(node)
+            # record max act memory
+            act_memory_peak_log.append(act_memory)
+            # delete useless memory
+            act_memory -= calculate_fwd_tmp(node)
+            act_memory -= _get_delete_node_size(node, user_to_last_uses)
+        act_memory_after_node_log.append(act_memory)
+
+    act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
+    param_memory = parameter_size(gm)
+    return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)
+
+
+def _estimate_chunk_forward_mem(gm: torch.fx.GraphModule, start_node, end_node, chunk_size):
+    node_size = 0
+    param_size = 0
+    for node in gm.graph.nodes:
+        node_size += calculate_fwd_tmp(node)
+        node_size += calculate_fwd_out(node)
+    param_size = parameter_size(gm)
+    return (node_size + param_size) / 1024 ** 2, param_size / 1024 ** 2
+
+
 def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
     new_shape = "["
     for idx, i in enumerate(shape):
```
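`_estimate_inference_mem` is a running liveness simulation: placeholders charge their input size up front, every op charges its transient workspace plus its output activation, the running sum is logged as a peak candidate, and then the workspace and any activations whose last user just ran are credited back. A self-contained sketch of the same bookkeeping, with hypothetical byte counts standing in for the `colossalai.fx.profiler` metadata (parameter memory omitted):

```python
# Standalone sketch of the running-sum bookkeeping in _estimate_inference_mem.
# Sizes are made-up byte counts; the real pass reads node.meta and
# calculate_fwd_tmp / activation_size from colossalai.fx.profiler.

# (name, op, out_bytes, tmp_bytes, args) -- args name the producers each node reads
TOY_GRAPH = [
    ("x",    "placeholder", 4096, 0,    []),
    ("lin0", "call_module", 8192, 1024, ["x"]),
    ("act",  "call_module", 8192, 0,    ["lin0"]),
    ("lin1", "call_module", 4096, 1024, ["act"]),
    ("out",  "output",      0,    0,    ["lin1"]),
]

def last_uses(nodes):
    """Map each node to the inputs that die right after it runs (cf. _get_last_usr)."""
    seen, frees = set(), {name: [] for name, *_ in nodes}
    for name, _, _, _, args in reversed(nodes):
        for a in args:
            if a not in seen:
                seen.add(a)
                frees[name].append(a)
    return frees

def estimate(nodes):
    sizes = {name: out for name, _, out, _, _ in nodes}
    frees = last_uses(nodes)
    mem, peak = 0, 0
    for name, op, out, tmp, _ in nodes:
        if op == "placeholder":
            mem += out                                  # inputs resident from the start
        elif op == "output":
            continue                                    # output allocates nothing here
        else:
            mem += tmp + out                            # workspace + this node's activation
            peak = max(peak, mem)
            mem -= tmp                                  # workspace is transient
            mem -= sum(sizes[a] for a in frees[name])   # free dead activations
    return peak, mem

print(estimate(TOY_GRAPH))   # -> (16384, 4096): peak bytes, residual bytes
```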
```diff
@@ -342,7 +419,7 @@ def emit_ckpt_func(body,
     body.append(usage)
 
-def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes):
+def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
     """Emit code with nested activation checkpoint
     When we detect some of the node.activation_checkpoint is a List, we will use
     this function to emit the activation checkpoint codes.
```
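For orientation: a chunk region re-executes a subgraph slice by slice along a chunk dimension and writes each partial result into a preallocated output, trading peak activation memory for extra loop iterations. A hand-written sketch of the *shape* of such emitted code; the names and the fixed dim-0 slicing are illustrative, not what the codegen literally prints:

```python
import torch

def chunked_forward(linear, x, chunk_size=2):
    # Preallocate the full output, then fill it one dim-0 slice at a time so
    # only one slice's intermediate activations are ever live.
    out = torch.empty(x.shape[0], linear.out_features)
    for start in range(0, x.shape[0], chunk_size):
        sl = slice(start, min(start + chunk_size, x.shape[0]))
        out[sl] = linear(x[sl])
    return out

x = torch.randn(8, 16)
lin = torch.nn.Linear(16, 32)
assert torch.allclose(chunked_forward(lin, x), lin(x), atol=1e-6)
```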
```diff
@@ -364,6 +441,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     within_chunk_region = False
     node_list = list(nodes)
+    _estimate_inference_mem(meta_graph)
 
     # find the input and output var names for each offload region
     for idx, (start, end) in enumerate(chunk_regions):
```
```diff
@@ -418,6 +496,7 @@ if CODEGEN_AVAILABLE:
     class ChunkCodeGen(CodeGen):
 
         def __init__(self, meta_graph):
            super().__init__()
+           self.meta_graph = meta_graph
            self.meta_node = list(meta_graph.graph.nodes)
 
        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
```
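torch.fx routes Python-source emission for `forward()` through the graph's `CodeGen` object, which is why overriding `_gen_python_code` is enough to inject the chunked loops. A toy subclass, independent of `ChunkCodeGen`, showing the hook (note `_gen_python_code` is a private API whose signature varies across torch versions, hence the `*args/**kwargs` passthrough):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.graph import CodeGen

class BannerCodeGen(CodeGen):
    """Toy CodeGen that prepends a comment to every generated forward()."""

    def _gen_python_code(self, nodes, root_module, namespace, *args, **kwargs):
        code = super()._gen_python_code(nodes, root_module, namespace, *args, **kwargs)
        code.src = "# emitted by BannerCodeGen\n" + code.src
        return code

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = symbolic_trace(Tiny())
gm.graph.set_codegen(BannerCodeGen())
gm.recompile()
print(gm.code)   # generated source now starts with the banner comment
```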
```diff
@@ -612,7 +691,7 @@ if CODEGEN_AVAILABLE:
         # if any node has a list of labels for activation_checkpoint, we
         # will use nested type of activation checkpoint codegen
-        emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node)
+        emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node, self.meta_graph)
 
         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
```
chunk_codegen_run.py
```diff
@@ -2,6 +2,7 @@ import copy
 import torch
 import torch.nn.functional as F
 import pytest
+import torch.fx
 import torch.multiprocessing as mp
 from torch.fx import GraphModule
 from colossalai.fx import ColoTracer
```
```diff
@@ -56,18 +57,15 @@ def _run_offload_codegen(rank):
     pair = torch.randn(1, 32, 32, 128).cuda()
 
     # trace the module and replace codegen
-    tracer = ColoTracer(trace_act_ckpt=True)
-    graph = tracer.trace(model)
-    gm_prop = torch.fx.GraphModule(model, graph)
-    interp = MetaInfoProp(gm_prop)
+    graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
+    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
+    interp = MetaInfoProp(gm_prop)
     interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
 
     # now run it twice to get meta info in graph module, not necessary
     gm = torch.fx.GraphModule(model, graph)
     interp = MetaInfoProp(gm)
     interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))
 
     # annotate the chunk part
     # for node in graph.nodes:
     #     if node.name == "linear0":
     #         setattr(node, "activation_offload", [0, True, False])
     #     if node.name == "linear1":
     #         setattr(node, "activation_offload", [0, True, False])
 
     codegen = ChunkCodeGen(gm_prop)
     graph.set_codegen(codegen)
```
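After `interp.propagate(...)`, each node's `node.meta` carries the entries the new estimator reads (`tensor_meta`, `fwd_out`, and the profiler's temporary sizes). A small helper for eyeballing them (helper name is illustrative; meta keys follow their usage in chunk_codegen.py above):

```python
import torch.fx

def dump_meta(gm: torch.fx.GraphModule):
    # Print the per-node metadata that _estimate_inference_mem consumes.
    for n in gm.graph.nodes:
        print(n.op, n.name, n.meta.get('tensor_meta'), n.meta.get('fwd_out'))

# e.g. dump_meta(gm_prop) after the interp.propagate(...) call above
```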