Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d9ca2f89
Commit
d9ca2f89
authored
Nov 15, 2022
by
oahzxl
Browse files
polish code
parent
54a34a7e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
60 deletions
+27
-60
chunk_codegen.py
chunk_codegen.py
+27
-60
No files found.
chunk_codegen.py
View file @
d9ca2f89
...
@@ -438,7 +438,7 @@ class MemoryEstimator(object):
...
@@ -438,7 +438,7 @@ class MemoryEstimator(object):
def
_get_delete_node_size
(
self
,
user
,
user_to_last_uses
):
def
_get_delete_node_size
(
self
,
user
,
user_to_last_uses
):
return
self
.
_get_delete_node
(
user
,
user_to_last_uses
)[
0
]
return
self
.
_get_delete_node
(
user
,
user_to_last_uses
)[
0
]
def
_remove_active_node
(
self
,
user
,
user_to_last_uses
,
active_list
):
def
_remove_
de
active_node
(
self
,
user
,
user_to_last_uses
,
active_list
):
delete_node
=
self
.
_get_delete_node
(
user
,
user_to_last_uses
)[
1
]
delete_node
=
self
.
_get_delete_node
(
user
,
user_to_last_uses
)[
1
]
for
i
in
delete_node
:
for
i
in
delete_node
:
active_list
.
remove
(
i
)
active_list
.
remove
(
i
)
...
@@ -481,48 +481,6 @@ class MemoryEstimator(object):
...
@@ -481,48 +481,6 @@ class MemoryEstimator(object):
return
mem
return
mem
def estimate_inference_mem(self, gm: torch.fx.GraphModule):
    """Walk `gm`'s FX graph in execution order and estimate activation memory.

    Accumulates each node's output size (in MB), releases tensors whose last
    use has passed, and logs the per-node peak / post-free memory via
    `_print_mem_log`.

    Returns:
        (act_memory + param_memory, param_memory), both in MB.
    """
    # NOTE(review): reconstructed from a token-mangled diff page; statement
    # order follows the original token stream, indentation is inferred —
    # confirm against the upstream file before relying on exact placement of
    # the active-node bookkeeping relative to the if/else.
    mb = 1024**2    # bytes per megabyte; identical to the original's (1024**2) divisors
    act_memory = 0.0
    act_memory_peak_log = []
    act_memory_after_node_log = []
    not_contiguous_list = []
    active_node_list = []
    active_node_list_log = []

    user_to_last_uses = self._get_last_usr(list(gm.graph.nodes))
    _delete_free_var_from_last_use(user_to_last_uses)

    for node in gm.graph.nodes:
        if node.op == 'placeholder':
            # Graph inputs only contribute their own (meta) size.
            act_memory += self._get_meta_node_size(node) / mb
            act_memory_peak_log.append(act_memory)
            active_node_list.append(node.name)
        elif node.op == 'output':
            # The output node allocates nothing; skip it.
            continue
        else:
            # Real op: temporary non-contiguous memory plus this node's output.
            act_memory += self._get_contiguous_memory(node, not_contiguous_list) / mb
            act_memory += self._get_output_node_size(node) / mb
            # Peak is reached before anything can be freed.
            act_memory_peak_log.append(act_memory)
            # Free tensors whose last user is this node, plus temporary
            # contiguous copies that are no longer needed.
            act_memory -= self._get_delete_node_size(node, user_to_last_uses) / mb
            act_memory -= self._get_contiguous_memory(node, not_contiguous_list, delete=True) / mb
        # Track which node outputs are currently alive (pre-rename helper:
        # this deleted function predates the `_remove_deactive_node` rename).
        self._add_active_node(node, active_node_list)
        self._remove_active_node(node, user_to_last_uses, active_node_list)
        act_memory_after_node_log.append(act_memory)
        active_node_list_log.append(copy.deepcopy(active_node_list))

    print("no chunk")
    self._print_mem_log(act_memory_peak_log, list(gm.graph.nodes), "peak")
    self._print_mem_log(act_memory_after_node_log, list(gm.graph.nodes), "after")
    param_memory = parameter_size(gm)
    return act_memory + param_memory, param_memory
def
_get_chunk_ratio
(
self
,
node
,
chunk_dim
,
chunk_size
):
def
_get_chunk_ratio
(
self
,
node
,
chunk_dim
,
chunk_size
):
shape
=
node
.
meta
[
'tensor_meta'
].
shape
shape
=
node
.
meta
[
'tensor_meta'
].
shape
chunk_ratio
=
float
(
chunk_size
)
/
shape
[
chunk_dim
]
chunk_ratio
=
float
(
chunk_size
)
/
shape
[
chunk_dim
]
...
@@ -550,25 +508,28 @@ class MemoryEstimator(object):
...
@@ -550,25 +508,28 @@ class MemoryEstimator(object):
print
(
""
)
print
(
""
)
print
(
"
\n
"
)
print
(
"
\n
"
)
def
estimate_chunk_inference_mem
(
self
,
gm
:
torch
.
fx
.
GraphModule
,
start_nodes
=
None
,
end_nodes
=
None
,
chunk_dims
=
None
,
chunk_sizes
=
None
):
def
estimate_chunk_inference_mem
(
self
,
gm
:
torch
.
fx
.
GraphModule
,
start_nodes
,
end_nodes
,
chunk_dims
,
chunk_sizes
):
act_memory
=
0.0
act_memory
=
0.0
act_memory_peak_log
=
[]
act_memory_peak_log
=
[]
act_memory_after_node_log
=
[]
act_memory_after_node_log
=
[]
active_node_list
=
[]
active_node_list_log
=
[]
not_contiguous_list
=
[]
not_contiguous_list
=
[]
node_list
=
list
(
gm
.
graph
.
nodes
)
user_to_last_uses
=
self
.
_get_last_usr
(
list
(
gm
.
graph
.
nodes
))
user_to_last_uses
=
self
.
_get_last_usr
(
list
(
gm
.
graph
.
nodes
))
_delete_free_var_from_last_use
(
user_to_last_uses
)
_delete_free_var_from_last_use
(
user_to_last_uses
)
within_chunk
=
False
region_idx
=
0
use_chunk
=
all
(
i
is
not
None
for
i
in
[
start_nodes
,
end_nodes
,
chunk_dims
,
chunk_sizes
])
chunk_within
=
False
chunk_region_idx
=
0
chunk_ratio
=
1
# use it to estimate chunk mem
chunk_ratio
=
1
# use it to estimate chunk mem
node_list
=
list
(
gm
.
graph
.
nodes
)
for
idx
,
node
in
enumerate
(
node_list
):
for
idx
,
node
in
enumerate
(
node_list
):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if
idx
in
start_nodes
:
if
use_chunk
and
idx
in
start_nodes
:
within
_chunk
=
True
chunk_
within
=
True
chunk_ratio
=
self
.
_get_chunk_ratio
(
node
,
chunk_dims
[
region_idx
],
chunk_sizes
[
region_idx
])
chunk_ratio
=
self
.
_get_chunk_ratio
(
node
,
chunk_dims
[
chunk_
region_idx
],
chunk_sizes
[
chunk_
region_idx
])
act_memory
+=
self
.
_get_output_node_size
(
node_list
[
end_nodes
[
region_idx
]])
/
(
1024
**
2
)
act_memory
+=
self
.
_get_output_node_size
(
node_list
[
end_nodes
[
chunk_
region_idx
]])
/
(
1024
**
2
)
# if node is placeholder, just add the size of the node
# if node is placeholder, just add the size of the node
if
node
.
op
==
'placeholder'
:
if
node
.
op
==
'placeholder'
:
...
@@ -586,22 +547,28 @@ class MemoryEstimator(object):
...
@@ -586,22 +547,28 @@ class MemoryEstimator(object):
act_memory_peak_log
.
append
(
act_memory
)
act_memory_peak_log
.
append
(
act_memory
)
# delete useless memory
# delete useless memory
act_memory
-=
self
.
_get_contiguous_memory
(
node
,
not_contiguous_list
,
delete
=
True
)
*
chunk_ratio
/
(
1024
**
2
)
act_memory
-=
self
.
_get_contiguous_memory
(
node
,
not_contiguous_list
,
delete
=
True
)
*
chunk_ratio
/
(
1024
**
2
)
if
within
_chunk
:
if
chunk_
within
:
act_memory
-=
self
.
_get_chunk_delete_node_size
(
act_memory
-=
self
.
_get_chunk_delete_node_size
(
node
,
user_to_last_uses
,
chunk_ratio
,
node_list
,
node
,
user_to_last_uses
,
chunk_ratio
,
node_list
,
start_nodes
[
region_idx
],
end_nodes
[
region_idx
])
/
(
1024
**
2
)
start_nodes
[
chunk_
region_idx
],
end_nodes
[
chunk_
region_idx
])
/
(
1024
**
2
)
else
:
else
:
act_memory
-=
self
.
_get_delete_node_size
(
node
,
user_to_last_uses
)
/
(
1024
**
2
)
act_memory
-=
self
.
_get_delete_node_size
(
node
,
user_to_last_uses
)
/
(
1024
**
2
)
if
idx
in
end_nodes
:
# log active node
self
.
_add_active_node
(
node
,
active_node_list
)
self
.
_remove_deactive_node
(
node
,
user_to_last_uses
,
active_node_list
)
# if node in chunk end nodes, restore chunk settings
if
use_chunk
and
idx
in
end_nodes
:
act_memory
-=
self
.
_get_output_node_size
(
node
)
*
chunk_ratio
/
(
1024
**
2
)
act_memory
-=
self
.
_get_output_node_size
(
node
)
*
chunk_ratio
/
(
1024
**
2
)
within
_chunk
=
False
chunk_
within
=
False
chunk_ratio
=
1
chunk_ratio
=
1
region_idx
+=
1
chunk_
region_idx
+=
1
act_memory_after_node_log
.
append
(
act_memory
)
act_memory_after_node_log
.
append
(
act_memory
)
active_node_list_log
.
append
(
copy
.
deepcopy
(
active_node_list
))
print
(
"chunk"
)
print
(
"
with chunk"
if
use_chunk
else
"without
chunk"
)
self
.
_print_mem_log
(
act_memory_peak_log
,
node_list
,
"peak"
)
self
.
_print_mem_log
(
act_memory_peak_log
,
node_list
,
"peak"
)
self
.
_print_mem_log
(
act_memory_after_node_log
,
node_list
,
"after"
)
self
.
_print_mem_log
(
act_memory_after_node_log
,
node_list
,
"after"
)
...
@@ -725,7 +692,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
...
@@ -725,7 +692,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
memory_estimator
=
MemoryEstimator
()
memory_estimator
=
MemoryEstimator
()
memory_estimator
.
estimate_chunk_inference_mem
(
meta_graph
,
chunk_starts
,
chunk_ends
,
[
1
],
[
2
])
memory_estimator
.
estimate_chunk_inference_mem
(
meta_graph
,
chunk_starts
,
chunk_ends
,
[
1
],
[
2
])
memory_estimator
.
estimate_inference_mem
(
meta_graph
)
memory_estimator
.
estimate_
chunk_
inference_mem
(
meta_graph
)
node_index_tracer
=
NodeIndexTracer
(
meta_graph
)
node_index_tracer
=
NodeIndexTracer
(
meta_graph
)
node_index_tracer
.
trace_node_idx
()
node_index_tracer
.
trace_node_idx
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment