OpenDAS / ColossalAI

Commit d7634af5, authored Nov 11, 2022 by oahzxl
finish memory estimation
Parent: 22f9c60b

Showing 2 changed files with 80 additions and 47 deletions:
  chunk_codegen.py      +70 -37
  chunk_codegen_run.py  +10 -10
chunk_codegen.py
@@ -55,15 +55,49 @@ def _get_last_usr(nodes):
     return user_to_last_uses
 
 
+def _delete_free_var_from_last_use(user_to_last_uses):
+    for key, value in user_to_last_uses.items():
+        for n in value:
+            if n.op == 'placeholder':
+                user_to_last_uses[key].remove(n)
+
+
+def _get_contiguous_memory(node, not_contiguous_list, delete=False):
+    mem = 0
+    not_contiguous_ops = ['transpose', 'permute']
+    if node.op == 'call_function' and 'matmul' in node.name:
+        for n in node.args:
+            if n in not_contiguous_list:
+                # matmul won't change origin tensor, but create a tmp copy
+                mem += _get_output_node_size(n)
+    elif node.op == 'call_module':
+        for n in node.args:
+            if n in not_contiguous_list:
+                # module will just make origin tensor to contiguous
+                if delete:
+                    not_contiguous_list.remove(n)
+    elif node.op == 'call_method' and any(i in node.name for i in not_contiguous_ops):
+        if node not in not_contiguous_list:
+            not_contiguous_list.append(node)
+    elif any(i in node.args for i in not_contiguous_list):
+        if node not in not_contiguous_list:
+            not_contiguous_list.append(node)
+    return mem
+
+
 def _estimate_inference_mem(gm: torch.fx.GraphModule):
-    act_memory = 0
+    act_memory = 0.0
     act_memory_peak_log = []
     act_memory_after_node_log = []
+    not_contiguous_list = []
     user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
     _delete_free_var_from_last_use(user_to_last_uses)
     for node in gm.graph.nodes:
         # if node is placeholder, just add the size of the node
         if node.op == 'placeholder':
-            act_memory += _get_meta_node_size(node)
+            act_memory += _get_meta_node_size(node) / (1024**2)
             act_memory_peak_log.append(act_memory)
             act_memory_after_node_log.append(act_memory)
         # skip output
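The `_get_contiguous_memory` bookkeeping follows stock PyTorch semantics: `transpose` and `permute` return non-contiguous views, and a downstream op such as `matmul` materializes a contiguous temporary copy instead of mutating the original tensor. A minimal standalone illustration in plain PyTorch (independent of this file's helpers):

    import torch

    x = torch.randn(4, 8)
    y = x.transpose(0, 1)                 # a view: no new storage yet
    print(y.is_contiguous())              # False, strides are no longer row-major
    z = y.contiguous()                    # materializes a fresh contiguous copy
    print(z.data_ptr() != x.data_ptr())   # True: extra memory, as the estimator charges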
@@ -72,25 +106,21 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
         # node is an operation, calculate tmp, output node and delete node memory
         else:
             # forward memory
             act_memory += calculate_fwd_tmp(node)
             # act_memory += calculate_fwd_out(node)
-            act_memory += _get_output_node_size(node)
+            act_memory += _get_contiguous_memory(node, not_contiguous_list) / (1024**2)
+            act_memory += _get_output_node_size(node) / (1024**2)
             # record max act memory
             act_memory_peak_log.append(act_memory)
             # delete useless memory
             act_memory -= calculate_fwd_tmp(node)
-            act_memory -= _get_delete_node_size(node, user_to_last_uses)
+            act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024**2)
+            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) / (1024**2)
             act_memory_after_node_log.append(act_memory)
-    act_memory_peak_log = [float(i) / (1024**2) for i in act_memory_peak_log]
-    act_memory_after_node_log = [float(i) / (1024**2) for i in act_memory_after_node_log]
     print("no chunk")
-    _print_mem_log(act_memory_peak_log, "peak")
-    _print_mem_log(act_memory_after_node_log, "after")
+    _print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
+    _print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
     param_memory = parameter_size(gm)
-    return (act_memory + param_memory) / (1024**2), param_memory / (1024**2)
+    return act_memory + param_memory, param_memory
 
 
 def _get_chunk_ratio(node, chunk_dim, chunk_size):
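After this change the estimator keeps a running MB total per node and samples the peak before deletions. A stripped-down sketch of that same pattern (hypothetical per-node `alloc_mb`/`free_mb` inputs, not part of this commit):

    def peak_from_events(alloc_mb, free_mb):
        cur, peaks, after = 0.0, [], []
        for a, f in zip(alloc_mb, free_mb):
            cur += a
            peaks.append(cur)     # peak is sampled before frees, as in the diff
            cur -= f
            after.append(cur)
        return peaks, after

    print(peak_from_events([4.0, 8.0, 2.0], [0.0, 8.0, 1.0]))
    # ([4.0, 12.0, 6.0], [4.0, 4.0, 5.0])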
@@ -111,19 +141,23 @@ def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list,
     return delete_size
 
 
-def _print_mem_log(log, title=None):
+def _print_mem_log(log, nodes, title=None):
     if title:
-        print("%-8s" % title, end=' ')
-    for i in log:
-        print("%.2f " % i, end='')
+        print(title)
+    for idx, (l, n) in enumerate(zip(log, nodes)):
+        print("%s:%.2f \t" % (n.name, l), end='')
+        if (idx + 1) % 3 == 0:
+            print("")
     print("\n")
 
 
 def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
-    act_memory = 0
+    act_memory = 0.0
     act_memory_peak_log = []
     act_memory_after_node_log = []
+    not_contiguous_list = []
     user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
     _delete_free_var_from_last_use(user_to_last_uses)
     within_chunk = False
     region_idx = 0
     chunk_ratio = 1    # use it to estimate chunk mem
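With the new `nodes` argument each logged value is labeled by its node name, three entries per line. A quick usage sketch with stand-in nodes (`SimpleNamespace` here only mimics the `name` attribute of `torch.fx.Node`):

    from types import SimpleNamespace

    log = [1.25, 3.50, 2.75, 4.00]
    nodes = [SimpleNamespace(name="n%d" % i) for i in range(4)]
    _print_mem_log(log, nodes, "peak")
    # peak
    # n0:1.25   n1:3.50   n2:2.75
    # n3:4.00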
@@ -134,11 +168,11 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
         if idx in start_nodes:
             within_chunk = True
             chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
-            act_memory += _get_output_node_size(node_list[end_nodes[region_idx]])
+            act_memory += _get_output_node_size(node_list[end_nodes[region_idx]]) / (1024**2)
 
         # if node is placeholder, just add the size of the node
         if node.op == 'placeholder':
-            act_memory += _get_meta_node_size(node) * chunk_ratio
+            act_memory += _get_meta_node_size(node) * chunk_ratio / (1024**2)
             act_memory_peak_log.append(act_memory)
         # skip output
         elif node.op == 'output':
@@ -146,36 +180,33 @@ def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nod
         # node is an operation, calculate tmp, output node and delete node memory
         else:
             # forward memory
             act_memory += calculate_fwd_tmp(node) * chunk_ratio
             # act_memory += calculate_fwd_out(node)
-            act_memory += _get_output_node_size(node) * chunk_ratio
+            act_memory += _get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2)
+            act_memory += _get_output_node_size(node) * chunk_ratio / (1024**2)
             # record max act memory
             act_memory_peak_log.append(act_memory)
             # delete useless memory
             act_memory -= calculate_fwd_tmp(node) * chunk_ratio
+            act_memory -= _get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio / (1024**2)
             if within_chunk:
                 act_memory -= _get_chunk_delete_node_size(node, user_to_last_uses, chunk_ratio, node_list,
-                                                          start_nodes[region_idx], end_nodes[region_idx])
+                                                          start_nodes[region_idx], end_nodes[region_idx]) / (1024**2)
             else:
-                act_memory -= _get_delete_node_size(node, user_to_last_uses)
+                act_memory -= _get_delete_node_size(node, user_to_last_uses) / (1024**2)
         if idx in end_nodes:
-            act_memory -= _get_output_node_size(node) * chunk_ratio
+            act_memory -= _get_output_node_size(node) * chunk_ratio / (1024**2)
             within_chunk = False
             chunk_ratio = 1
             region_idx += 1
         act_memory_after_node_log.append(act_memory)
-    act_memory_peak_log = [float(i) / (1024**2) for i in act_memory_peak_log]
-    act_memory_after_node_log = [float(i) / (1024**2) for i in act_memory_after_node_log]
     print("chunk")
-    _print_mem_log(act_memory_peak_log, "peak")
-    _print_mem_log(act_memory_after_node_log, "after")
+    _print_mem_log(act_memory_peak_log, node_list, "peak")
+    _print_mem_log(act_memory_after_node_log, node_list, "after")
     param_memory = parameter_size(gm)
-    return (act_memory + param_memory) / (1024**2), param_memory / (1024**2)
+    return act_memory + param_memory, param_memory
 
 
 def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
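The `chunk_ratio` scaling rests on simple proportionality: if only a `chunk_size`-wide slice along `chunk_dim` is live at a time, each intermediate in the region costs that fraction of its full size. A hedged sketch of the arithmetic (hypothetical shapes, not the commit's `_get_chunk_ratio`):

    def chunk_ratio(shape, chunk_dim, chunk_size):
        # fraction of the full tensor resident when one slice is materialized
        return chunk_size / shape[chunk_dim]

    # a (1, 300, 300, 256) activation chunked along dim 1 in slices of 30
    print(chunk_ratio((1, 300, 300, 256), 1, 30))   # 0.1 -> ~10x less activation memory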
@@ -516,7 +547,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     """
     # find the offload regions
-    chunk_regions = [(2, 6)]
+    chunk_regions = [(58, 62)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
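The region markers drive a slice-compute-write rewrite of the covered nodes. A hedged illustration of the shape of such a loop (a toy `f` standing in for the node-58..62 region, not the actual emitted code):

    import torch

    def f(x):
        return x * 2 + 1    # placeholder for the chunked region

    x = torch.randn(1, 300, 256)
    out = torch.empty_like(x)
    chunk_size = 30
    for i in range(0, x.shape[1], chunk_size):    # chunk along dim 1
        out[:, i:i + chunk_size] = f(x[:, i:i + chunk_size])
    assert torch.allclose(out, f(x))    # same result, smaller live intermediates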
@@ -684,6 +715,8 @@ if CODEGEN_AVAILABLE:
                 map_arg(node.args, lambda n: register_last_uses(n, node))
                 map_arg(node.kwargs, lambda n: register_last_uses(n, node))
 
+        _delete_free_var_from_last_use(user_to_last_uses)
+
         # NOTE: we add a variable to distinguish body and ckpt_func
         def delete_unused_values(user: Node, body, to_keep=[]):
             """
chunk_codegen_run.py
@@ -32,14 +32,14 @@ def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
 def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
-    # now_mem = torch.cuda.memory_allocated() / 1024**2
-    # max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print("now:%.2f max:%.2f" % (torch.cuda.memory_allocated() / 1024**2, torch.cuda.max_memory_allocated() / 1024**2))
-    # with torch.no_grad():
-    #     fx_out = gm(node, pair)
-    #     new_now_mem = torch.cuda.memory_allocated() / 1024**2
-    #     new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    #     print("now:%.2f max:%.2f" % (new_now_mem - now_mem, new_max_mem - max_mem))
+    now_mem = torch.cuda.memory_allocated() / 1024**2
+    with torch.no_grad():
+        node0 = node.clone()
+        pair0 = pair.clone()
+        node1, pair1 = gm(node0, pair0)
+    new_now_mem = torch.cuda.memory_allocated() / 1024**2
+    new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+    print("now:%.2f max:%.2f" % (new_now_mem - now_mem, new_max_mem - now_mem))
 
     # test forward
     with torch.no_grad():
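For reference, `torch.cuda.memory_allocated()` reports bytes currently held by tensors and `torch.cuda.max_memory_allocated()` the high-water mark since the last reset, so deltas around a forward pass isolate its footprint. A minimal standalone pattern (requires a CUDA device; `model` and `x` are placeholders):

    import torch

    def measure_fwd_mem(model, x):
        torch.cuda.reset_peak_memory_stats()                # restart the high-water mark
        base = torch.cuda.memory_allocated() / 1024**2
        with torch.no_grad():
            out = model(x)
        now = torch.cuda.memory_allocated() / 1024**2       # still held (e.g. outputs)
        peak = torch.cuda.max_memory_allocated() / 1024**2  # transient peak during forward
        print("now:%.2f max:%.2f" % (now - base, peak - base))
        return out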
@@ -63,8 +63,8 @@ def _run_offload_codegen(rank):
     # build model and input
     model = evoformer_base().cuda()
-    node = torch.randn(1, 16, 32, 256).cuda()
-    pair = torch.randn(1, 32, 32, 128).cuda()
+    node = torch.randn(1, 100, 300, 256).cuda()
+    pair = torch.randn(1, 300, 300, 128).cuda()
 
     # trace the module and replace codegen
     graph = ColoTracer().trace(model, meta_args={'node': node.to(torch.device('meta')), 'pair': pair.to(torch.device('meta'))})
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment