Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
54a34a7e
Commit
54a34a7e
authored
Nov 15, 2022
by
oahzxl
Browse files
update active log
parent
fad3b6d1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
13 deletions
+43
-13
chunk_codegen.py
chunk_codegen.py
+43
-13
No files found.
chunk_codegen.py
View file @
54a34a7e
...
...
@@ -407,18 +407,41 @@ class MemoryEstimator(object):
x
=
x
.
numel
*
torch
.
tensor
([],
dtype
=
x
.
dtype
).
element_size
()
return
x
def
_get_output_node
_size
(
self
,
n
):
def
_get_output_node
(
self
,
n
):
fwd_out
=
{
x
.
uuid
:
x
for
x
in
n
.
meta
[
"fwd_out"
]
if
isinstance
(
x
,
torch
.
Tensor
)
and
hasattr
(
x
,
'uuid'
)}
return
activation_size
(
fwd_out
)
out_size
=
activation_size
(
fwd_out
)
out_node
=
[
n
.
name
]
if
out_size
>
0
else
[]
return
out_size
,
out_node
def
_get_delete_node_size
(
self
,
user
,
user_to_last_uses
):
if
user
.
op
in
(
'placeholder'
,
'output'
):
return
0
def
_get_output_node_size
(
self
,
n
):
return
self
.
_get_output_node
(
n
)[
0
]
def
_add_active_node
(
self
,
n
,
active_list
):
new_active
=
self
.
_get_output_node
(
n
)[
1
]
for
i
in
new_active
:
if
i
not
in
active_list
:
active_list
.
append
(
i
)
def
_get_delete_node
(
self
,
user
,
user_to_last_uses
):
delete_size
=
0
delete_node
=
[]
if
user
.
op
not
in
(
'placeholder'
,
'output'
):
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
if
len
(
nodes_to_delete
):
delete_size
=
sum
([
self
.
_get_output_node_size
(
i
)
for
i
in
nodes_to_delete
])
return
delete_size
return
0
out_node
=
[
self
.
_get_output_node
(
i
)
for
i
in
nodes_to_delete
]
delete_size
=
sum
([
i
[
0
]
for
i
in
out_node
])
for
i
in
range
(
len
(
out_node
)):
if
out_node
[
i
][
0
]
>
0
:
delete_node
.
append
(
out_node
[
i
][
1
][
0
])
return
delete_size
,
delete_node
def
_get_delete_node_size
(
self
,
user
,
user_to_last_uses
):
return
self
.
_get_delete_node
(
user
,
user_to_last_uses
)[
0
]
def
_remove_active_node
(
self
,
user
,
user_to_last_uses
,
active_list
):
delete_node
=
self
.
_get_delete_node
(
user
,
user_to_last_uses
)[
1
]
for
i
in
delete_node
:
active_list
.
remove
(
i
)
def
_get_last_usr
(
self
,
nodes
):
node_to_last_use
:
Dict
[
Node
,
Node
]
=
{}
...
...
@@ -438,7 +461,7 @@ class MemoryEstimator(object):
mem
=
0
not_contiguous_ops
=
[
'transpose'
,
'permute'
]
if
node
.
op
==
'call_function'
and
'matmul'
in
node
.
name
:
if
node
.
op
==
'call_function'
and
any
(
n
in
node
.
name
for
n
in
[
'matmul'
,
'reshape'
])
:
for
n
in
node
.
args
:
if
n
in
not_contiguous_list
:
# matmul won't change origin tensor, but create a tmp copy
...
...
@@ -463,6 +486,8 @@ class MemoryEstimator(object):
act_memory_peak_log
=
[]
act_memory_after_node_log
=
[]
not_contiguous_list
=
[]
active_node_list
=
[]
active_node_list_log
=
[]
user_to_last_uses
=
self
.
_get_last_usr
(
list
(
gm
.
graph
.
nodes
))
_delete_free_var_from_last_use
(
user_to_last_uses
)
for
node
in
gm
.
graph
.
nodes
:
...
...
@@ -470,7 +495,7 @@ class MemoryEstimator(object):
if
node
.
op
==
'placeholder'
:
act_memory
+=
self
.
_get_meta_node_size
(
node
)
/
(
1024
**
2
)
act_memory_peak_log
.
append
(
act_memory
)
act
_memory_after
_node_l
og
.
append
(
act_memory
)
act
ive
_node_l
ist
.
append
(
node
.
name
)
# skip output
elif
node
.
op
==
'output'
:
continue
...
...
@@ -484,8 +509,12 @@ class MemoryEstimator(object):
# delete useless memory
act_memory
-=
self
.
_get_delete_node_size
(
node
,
user_to_last_uses
)
/
(
1024
**
2
)
act_memory
-=
self
.
_get_contiguous_memory
(
node
,
not_contiguous_list
,
delete
=
True
)
/
(
1024
**
2
)
act_memory_after_node_log
.
append
(
act_memory
)
# log active node
self
.
_add_active_node
(
node
,
active_node_list
)
self
.
_remove_active_node
(
node
,
user_to_last_uses
,
active_node_list
)
act_memory_after_node_log
.
append
(
act_memory
)
active_node_list_log
.
append
(
copy
.
deepcopy
(
active_node_list
))
print
(
"no chunk"
)
self
.
_print_mem_log
(
act_memory_peak_log
,
list
(
gm
.
graph
.
nodes
),
"peak"
)
self
.
_print_mem_log
(
act_memory_after_node_log
,
list
(
gm
.
graph
.
nodes
),
"after"
)
...
...
@@ -551,7 +580,6 @@ class MemoryEstimator(object):
# node is an operation, calculate tmp, output node and delete node memory
else
:
# forward memory
# TODO: permute will create a tmp copy if not contiguous
act_memory
+=
self
.
_get_contiguous_memory
(
node
,
not_contiguous_list
)
*
chunk_ratio
/
(
1024
**
2
)
act_memory
+=
self
.
_get_output_node_size
(
node
)
*
chunk_ratio
/
(
1024
**
2
)
# record max act memory
...
...
@@ -694,9 +722,11 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
within_chunk_region
=
False
node_list
=
list
(
nodes
)
memory_estimator
=
MemoryEstimator
()
memory_estimator
.
estimate_chunk_inference_mem
(
meta_graph
,
chunk_starts
,
chunk_ends
,
[
1
],
[
2
])
memory_estimator
.
estimate_inference_mem
(
meta_graph
)
node_index_tracer
=
NodeIndexTracer
(
meta_graph
)
node_index_tracer
.
trace_node_idx
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment