OpenDAS / ColossalAI · Commit 8cca684c
Authored Nov 08, 2022 by oahzxl
Parent: 12301dd2

    finish memory estimation

Showing 1 changed file with 88 additions and 15 deletions.

chunk_codegen.py  (+88, -15)
@@ -85,25 +85,97 @@ def _estimate_inference_mem(gm: torch.fx.GraphModule):
    act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
    act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]
    # for i in act_memory_peak_log:
    #     print("%.2f " % i, end='')
    # print("\n")
    # for i in act_memory_after_node_log:
    #     print("%.2f " % i, end='')
    # print("\n")
    print("no chunk")
    _print_mem_log(act_memory_peak_log, "peak")
    _print_mem_log(act_memory_after_node_log, "after")
    param_memory = parameter_size(gm)
    return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)


def _estimate_chunk_forward_mem(gm: torch.fx.GraphModule, start_node, end_node, chunk_size):
    node_size = 0
    param_size = 0
    for node in gm.graph.nodes:
        node_size += calculate_fwd_tmp(node)
        node_size += calculate_fwd_out(node)
    param_size = parameter_size(gm)
    return (node_size + param_size) / 1024 ** 2, param_size / 1024 ** 2


def _get_chunk_ratio(node, chunk_dim, chunk_size):
    shape = node.meta['tensor_meta'].shape
    chunk_ratio = float(chunk_size) / shape[chunk_dim]
    return chunk_ratio


def _get_chunk_delete_node_size(user, user_to_last_uses, chunk_ratio, node_list, start_node, end_node):
    if user.op in ('placeholder', 'output'):
        return 0
    nodes_to_delete = user_to_last_uses.get(user, [])
    delete_size = 0
    for n in nodes_to_delete:
        node_idx = _find_idx_by_name(n.name, node_list)
        if start_node <= node_idx < end_node:
            delete_size += _get_output_node_size(n) * chunk_ratio
    return delete_size


def _print_mem_log(log, title=None):
    if title:
        print("%-8s" % title, end=' ')
    for i in log:
        print("%.2f " % i, end='')
    print("")


def _estimate_chunk_inference_mem(gm: torch.fx.GraphModule, start_nodes, end_nodes, chunk_dims, chunk_sizes):
    act_memory = 0
    act_memory_peak_log = []
    act_memory_after_node_log = []
    user_to_last_uses = _get_last_usr(list(gm.graph.nodes))
    within_chunk = False
    region_idx = 0
    chunk_ratio = 1    # use it to estimate chunk mem
    node_list = list(gm.graph.nodes)

    for idx, node in enumerate(node_list):
        # if node in chunk start nodes, change chunk ratio and add chunk_tensor
        if idx in start_nodes:
            within_chunk = True
            chunk_ratio = _get_chunk_ratio(node, chunk_dims[region_idx], chunk_sizes[region_idx])
            act_memory += _get_output_node_size(node_list[end_nodes[region_idx]])

        # if node is placeholder, just add the size of the node
        if node.op == 'placeholder':
            act_memory += _get_meta_node_size(node) * chunk_ratio
            act_memory_peak_log.append(act_memory)
        # skip output
        elif node.op == 'output':
            continue
        # node is an operation, calculate tmp, output node and delete node memory
        else:
            # forward memory
            act_memory += calculate_fwd_tmp(node) * chunk_ratio
            # act_memory += calculate_fwd_out(node)
            act_memory += _get_output_node_size(node) * chunk_ratio
            # record max act memory
            act_memory_peak_log.append(act_memory)
            # delete useless memory
            act_memory -= calculate_fwd_tmp(node) * chunk_ratio
            if within_chunk:
                act_memory -= _get_chunk_delete_node_size(node, user_to_last_uses, chunk_ratio, node_list,
                                                          start_nodes[region_idx], end_nodes[region_idx])
            else:
                act_memory -= _get_delete_node_size(node, user_to_last_uses)

        if idx in end_nodes:
            act_memory -= _get_output_node_size(node) * chunk_ratio
            within_chunk = False
            chunk_ratio = 1
            region_idx += 1

        act_memory_after_node_log.append(act_memory)

    act_memory_peak_log = [float(i) / (1024 ** 2) for i in act_memory_peak_log]
    act_memory_after_node_log = [float(i) / (1024 ** 2) for i in act_memory_after_node_log]

    print("chunk")
    _print_mem_log(act_memory_peak_log, "peak")
    _print_mem_log(act_memory_after_node_log, "after")

    param_memory = parameter_size(gm)
    return (act_memory + param_memory) / (1024 ** 2), param_memory / (1024 ** 2)


def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape):
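The hunk above leans on `_get_last_usr`, which is defined elsewhere in the file and not shown in this diff. For orientation, here is a minimal sketch of what such a last-use map presumably computes, mirroring the liveness bookkeeping that torch.fx's own code generator uses: each node is mapped to the list of values whose last use it is, so their memory can be released right after that node runs. The function name and exact semantics here are assumptions, not code from this commit.

def _get_last_usr_sketch(nodes):
    # Assumed semantics (illustrative only): walk the node list in reverse,
    # so the first time we encounter a value being used, that user is its
    # *last* use in execution order. user_to_last_uses[user] then lists the
    # values that become dead immediately after `user` executes.
    node_to_last_use = {}
    user_to_last_uses = {}
    for node in reversed(nodes):
        for inp in node.all_input_nodes:    # torch.fx.Node input values
            if inp not in node_to_last_use:
                node_to_last_use[inp] = node
                user_to_last_uses.setdefault(node, []).append(inp)
    return user_to_last_uses

This is why `_get_chunk_delete_node_size` can query `user_to_last_uses.get(user, [])` and free each dead value's output size (scaled by the chunk ratio when the value lives inside the chunk region).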
@@ -444,7 +516,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     """
     # find the offload regions
-    chunk_regions = [(2, 5)]
+    chunk_regions = [(2, 6)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
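To make the chunk-ratio scaling from the first hunk concrete, here is a worked example with invented numbers (the shape and dtype below are illustrative and not taken from the commit; the dim and chunk size match the `[1]`, `[2]` arguments passed in the next hunk): chunking dim 1 of size 8 into slices of 2 gives a ratio of 0.25, so every activation produced inside the region is counted at a quarter of its full size.

# Illustrative numbers only: a node output of shape (16, 8, 64) in fp32,
# chunked along dim 1 (size 8) with chunk_size 2.
shape = (16, 8, 64)
chunk_dim, chunk_size = 1, 2
chunk_ratio = float(chunk_size) / shape[chunk_dim]    # 0.25, as in _get_chunk_ratio
full_bytes = 16 * 8 * 64 * 4                          # full fp32 output: 32768 bytes
chunked_bytes = full_bytes * chunk_ratio              # 8192 bytes counted inside the region
print(chunk_ratio, full_bytes, chunked_bytes)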
@@ -452,6 +524,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     within_chunk_region = False
     node_list = list(nodes)
+    _estimate_chunk_inference_mem(meta_graph, chunk_starts, chunk_ends, [1], [2])
     _estimate_inference_mem(meta_graph)
     # find the input and output var names for each offload region
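For intuition about what these two estimator calls print, here is a deliberately simplified, self-contained model of the same accounting. It ignores temporaries, placeholders, and the full-size region output that the real estimator reserves, and all sizes are invented; it only shows why peak memory drops inside a chunked region, where each node's output is materialized one slice at a time.

# Toy model, not the real estimator: per-node output sizes (MB) for a
# straight-line graph where each output dies once the next node has run.
out_sizes = [4.0, 4.0, 16.0, 16.0, 16.0, 4.0]


def toy_peak(sizes, start=None, end=None, ratio=1.0):
    def r(i):
        # scale nodes inside the half-open chunk region [start, end)
        return ratio if start is not None and start <= i < end else 1.0

    mem, peak = 0.0, 0.0
    for i, size in enumerate(sizes):
        mem += size * r(i)                    # materialize this node's output
        peak = max(peak, mem)
        if i > 0:
            mem -= sizes[i - 1] * r(i - 1)    # previous output is now dead
    return peak


print("no chunk peak:", toy_peak(out_sizes))                             # 32.0 MB
print("chunked peak: ", toy_peak(out_sizes, start=2, end=5, ratio=0.25))  # 8.0 MB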