OpenDAS / ColossalAI / Commits

Commit d95cfe26, authored Nov 07, 2022 by oahzxl
"basic memory"
parent c35718e8
Showing 2 changed files with 90 additions and 13 deletions:

  chunk_codegen.py        +81 −2
  chunk_codegen_run.py    +9 −11
chunk_codegen.py (view file @ d95cfe26)
@@ -6,6 +6,7 @@ from typing import List, Callable, Any, Tuple, Dict, Iterable
 try:
     from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
     from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
+    from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, parameter_size, activation_size
     CODEGEN_AVAILABLE = True
 except:
     from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
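The new profiler import lands inside the existing try block, so an environment where colossalai.fx.profiler is unavailable falls through to the except branch together with the other optional torch.fx internals. A minimal sketch of this guarded-import pattern (the module name here is a hypothetical stand-in):

# Guarded-import pattern, as used above; `optional_dep` is a stand-in name.
try:
    from optional_dep import helper    # any import failure sends us to the fallback
    CODEGEN_AVAILABLE = True
except ImportError:
    CODEGEN_AVAILABLE = False          # later code gates on this flag, e.g. `if CODEGEN_AVAILABLE:`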
@@ -18,6 +19,82 @@ else:
 __all__ = ['python_code_with_activation_checkpoint']


+def _get_meta_node_size(x):
+    x = x.meta['tensor_meta']
+    x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
+    return x
+
+
+def _get_output_node_size(n):
+    fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
+    return activation_size(fwd_out)
+
+
+def _get_delete_node_size(user, user_to_last_uses):
+    if user.op in ('placeholder', 'output'):
+        return 0
+    nodes_to_delete = user_to_last_uses.get(user, [])
+    if len(nodes_to_delete):
+        delete_size = sum([_get_output_node_size(i) for i in nodes_to_delete])
+        return delete_size
+    return 0
+
+
+def _get_last_usr(nodes):
+    node_to_last_use: Dict[Node, Node] = {}
+    user_to_last_uses: Dict[Node, List[Node]] = {}
+
+    def register_last_uses(n: Node, user: Node):
+        if n not in node_to_last_use:
+            node_to_last_use[n] = user
+            user_to_last_uses.setdefault(user, []).append(n)
+
+    for node in reversed(nodes):
+        map_arg(node.args, lambda n: register_last_uses(n, node))
+        map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+    return user_to_last_uses
+
+
+def _estimate_inference_mem(gm: torch.fx.GraphModule):
+    act_memory = 0
+    act_memory_peak_log = []
+    act_memory_after_node_log = []
+    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
+    for node in gm.graph.nodes:
+        # if node is a placeholder, just add its size
+        if node.op == 'placeholder':
+            act_memory += _get_meta_node_size(node)
+        # skip output
+        elif node.op == 'output':
+            continue
+        # node is an operation: add its tmp and output memory, then free dead tensors
+        else:
+            # forward memory
+            act_memory += calculate_fwd_tmp(node)
+            # act_memory += calculate_fwd_out(node)
+            act_memory += _get_output_node_size(node)
+            # record max act memory
+            act_memory_peak_log.append(act_memory)
+            # free memory that is no longer needed
+            act_memory -= calculate_fwd_tmp(node)
+            act_memory -= _get_delete_node_size(node, user_to_last_uses)
+        act_memory_after_node_log.append(act_memory)
+    act_memory_peak_log = [float(i) / (1024**2) for i in act_memory_peak_log]
+    param_memory = parameter_size(gm)
+    return (act_memory + param_memory) / (1024**2), param_memory / (1024**2)
+
+
+def _estimate_chunk_forward_mem(gm: torch.fx.GraphModule, start_node, end_node, chunk_size):
+    node_size = 0
+    param_size = 0
+    for node in gm.graph.nodes:
+        node_size += calculate_fwd_tmp(node)
+        node_size += calculate_fwd_out(node)
+    param_size = parameter_size(gm)
+    return (node_size + param_size) / 1024**2, param_size / 1024**2
+
+
 def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
     new_shape = "["
     for idx, i in enumerate(shape):
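_estimate_inference_mem walks the graph once: each operation first charges its temporary and output buffers, the running total is recorded as a candidate peak, and then buffers whose last use has passed (per _get_last_usr) are credited back. A minimal, self-contained re-enactment of that bookkeeping with invented byte sizes:

# Toy re-enactment of the peak-memory sweep in _estimate_inference_mem.
# Each op is (name, tmp_bytes, out_bytes, frees=[names whose last use is here]).
ops = [
    ("linear0", 1024, 4096, []),
    ("relu0",      0, 4096, ["linear0"]),   # linear0's output dies after relu0
    ("linear1", 1024, 4096, ["relu0"]),
]
out_size = {name: out for name, _, out, _ in ops}

act, peaks, after = 0, [], []
for name, tmp, out, frees in ops:
    act += tmp + out                         # charge temporaries and this op's output
    peaks.append(act)                        # candidate peak right after the op runs
    act -= tmp                               # temporaries are released immediately
    act -= sum(out_size[f] for f in frees)   # outputs past their last use are released
    after.append(act)

print(peaks)   # [5120, 8192, 9216]
print(after)   # [4096, 4096, 4096]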
@@ -342,7 +419,7 @@ def emit_ckpt_func(body,
         body.append(usage)


-def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes):
+def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
     """Emit code with nested activation checkpoint
     When we detect some of the node.activation_checkpoint is a List, we will use
     this function to emit the activation checkpoint codes.
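Per the docstring, this path is taken when a node's activation_checkpoint attribute is list-valued. A hypothetical annotation pass (attribute values invented for illustration, echoing the commented-out example in chunk_codegen_run.py below) might look like:

# Hypothetical annotation that would route these nodes through emit_code_with_chunk.
for node in graph.nodes:
    if node.name in ("linear0", "linear1"):
        setattr(node, "activation_checkpoint", [0])    # a List-valued label triggers the nested codegen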
@@ -364,6 +441,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     within_chunk_region = False

     node_list = list(nodes)
+    _estimate_inference_mem(meta_graph)

     # find the input and output var names for each offload region
     for idx, (start, end) in enumerate(chunk_regions):
@@ -418,6 +496,7 @@ if CODEGEN_AVAILABLE:
     class ChunkCodeGen(CodeGen):

         def __init__(self, meta_graph):
             super().__init__()
+            self.meta_graph = meta_graph
             self.meta_node = list(meta_graph.graph.nodes)

         def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
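Storing the whole meta_graph (and not just its node list) matters because _estimate_inference_mem takes the module itself: it iterates meta_graph.graph.nodes and also passes the module to parameter_size. A usage sketch (names mirror the test file; capturing the return values, which are megabytes per the helper above, is hypothetical since the call in this commit discards them):

# Sketch: what storing the meta graph enables. gm_prop is a MetaInfoProp-annotated module.
codegen = ChunkCodeGen(gm_prop)
total_mb, param_mb = _estimate_inference_mem(codegen.meta_graph)   # returns megabytes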
@@ -612,7 +691,7 @@ if CODEGEN_AVAILABLE:
             # if any node has a list of labels for activation_checkpoint, we
             # will use nested type of activation checkpoint codegen
-            emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node)
+            emit_code_with_chunk(body, ckpt_func, nodes, emit_node, delete_unused_values, self.meta_node, self.meta_graph)

         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
chunk_codegen_run.py (view file @ d95cfe26)
@@ -2,6 +2,7 @@ import copy
 import torch
 import torch.nn.functional as F
 import pytest
+import torch.fx
 import torch.multiprocessing as mp
 from torch.fx import GraphModule
 from colossalai.fx import ColoTracer
@@ -56,18 +57,15 @@ def _run_offload_codegen(rank):
     pair = torch.randn(1, 32, 32, 128).cuda()

     # trace the module and replace codegen
-    tracer = ColoTracer(trace_act_ckpt=True)
-    graph = tracer.trace(model)
-    gm_prop = torch.fx.GraphModule(model, graph)
+    graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
+    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
     interp = MetaInfoProp(gm_prop)
     interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))

-    # annotate the chunk part
-    # for node in graph.nodes:
-    #     if node.name == "linear0":
-    #         setattr(node, "activation_offload", [0, True, False])
-    #     if node.name == "linear1":
-    #         setattr(node, "activation_offload", [0, True, False])
+    # now run it twice to get meta info in graph module, not necessary
+    gm = torch.fx.GraphModule(model, graph)
+    interp = MetaInfoProp(gm)
+    interp.propagate(MetaTensor(node, fake_device='cuda:0'), MetaTensor(pair, fake_device='cuda:0'))

     codegen = ChunkCodeGen(gm_prop)
     graph.set_codegen(codegen)
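The test now traces the model twice: ColoTracer with meta_args yields the graph that receives the chunk codegen, while the torch.fx.symbolic_trace copy (gm_prop) carries the MetaInfoProp annotations that ChunkCodeGen reads. A plausible continuation to exercise the generated code (the build-and-run step below is an assumption, not part of this diff):

# Assumed follow-up to the hunk above: materialize and run the chunk-generated code.
gm = torch.fx.GraphModule(model, graph)    # picks up the ChunkCodeGen set via graph.set_codegen
gm.recompile()                             # regenerate forward() through the custom codegen
out = gm(node, pair)                       # runs the emitted chunked forward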