Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e83e3c61
Commit
e83e3c61
authored
Dec 16, 2022
by
oahzxl
Browse files
update memory estimate
parent
de65e6c3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
111 additions
and
66 deletions
+111
-66
chunk_codegen.py
chunk_codegen.py
+111
-66
No files found.
chunk_codegen.py
View file @
e83e3c61
...
...
@@ -896,23 +896,22 @@ class IndexTracer(object):
def _find_inherit_dim(self, input_node, input_dim, node):
    """Find which dimension of ``node`` is traced back to ``input_dim`` of ``input_node``.

    The source trace of ``node`` (from ``_find_source_trace_from_node``) is
    indexed per dimension; each entry is looked up by the position of
    ``input_node`` in ``self.nodes_list``.

    Args:
        input_node: node whose dimension is being followed.
        input_dim: dimension index on ``input_node``.
        node: node whose dimensions are searched.

    Returns:
        The first dimension index of ``node`` whose source trace maps
        ``input_node`` to ``input_dim``, or ``None`` when there is none.
        Note that ``0`` is a valid result, so callers should compare
        against ``None`` rather than rely on truthiness.
    """
    input_node_idx = _find_idx_by_name(input_node.name, self.nodes_list)
    node_trace_source = self._find_source_trace_from_node(node)
    for node_dim in range(len(_get_node_shape(node))):
        dim_source = node_trace_source[node_dim]
        if input_node_idx in dim_source and dim_source[input_node_idx] == input_dim:
            return node_dim
    return None
def
check_index_duplicate
(
self
,
chunk_infos
):
input_dim_after_node
=
{}
for
input_node_idx
,
input_node
in
enumerate
(
chunk_infos
[
"inputs"
]):
for
k
,
v
in
chunk_infos
[
"inputs_dim"
][
input_node_idx
].
items
():
in
pu
t_dim
_after_node
.
update
(
self
.
_find_inherit_dim
(
input_node
,
v
,
self
.
nodes_list
[
k
])
)
in
heri
t_dim
=
self
.
_find_inherit_dim
(
input_node
,
v
,
self
.
nodes_list
[
k
])
if
inherit_dim
:
input_dim_after_node
[
k
]
=
inherit_dim
for
node
in
self
.
nodes_list
[
chunk_infos
[
"region"
][
0
]
:
chunk_infos
[
"region"
][
1
]
+
1
...
...
@@ -934,8 +933,8 @@ class IndexTracer(object):
class
MemoryEstimator
(
object
):
def __init__(self, index_tracer: IndexTracer) -> None:
    """Store the index tracer used by the estimator.

    Args:
        index_tracer: tracer kept on ``self.index_tracer`` for later
            lookups by the estimator's methods.
    """
    # A single constructor: an earlier no-argument ``__init__`` would be
    # silently shadowed by this one, so it is not kept.
    self.index_tracer = index_tracer
def
_get_meta_node_size
(
self
,
x
):
x
=
x
.
meta
[
"tensor_meta"
]
...
...
@@ -950,6 +949,8 @@ class MemoryEstimator(object):
}
out_size
=
activation_size
(
fwd_out
)
out_node
=
[
n
.
name
]
if
out_size
>
0
else
[]
# if any(i in n.name for i in ['transpose', 'permute', 'view']):
# out_size = 0
return
out_size
,
out_node
def
_get_output_node_size
(
self
,
n
):
...
...
@@ -961,11 +962,19 @@ class MemoryEstimator(object):
if
i
not
in
active_list
:
active_list
.
append
(
i
)
def
_get_delete_node
(
self
,
user
,
user_to_last_uses
):
def
_get_delete_node
(
self
,
user
,
user_to_last_uses
,
to_keep
=
None
):
delete_size
=
0
delete_node
=
[]
if
user
.
op
not
in
(
"placeholder"
,
"output"
):
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
if
to_keep
is
not
None
:
keep_list
=
[]
for
n
in
nodes_to_delete
:
if
n
.
name
in
to_keep
:
keep_list
.
append
(
n
)
for
n
in
keep_list
:
if
n
in
nodes_to_delete
:
nodes_to_delete
.
remove
(
n
)
if
len
(
nodes_to_delete
):
out_node
=
[
self
.
_get_output_node
(
i
)
for
i
in
nodes_to_delete
]
delete_size
=
sum
([
i
[
0
]
for
i
in
out_node
])
...
...
@@ -974,16 +983,31 @@ class MemoryEstimator(object):
delete_node
.
append
(
out_node
[
i
][
1
][
0
])
elif
nodes_to_delete
[
i
].
op
==
"placeholder"
:
delete_node
.
append
(
nodes_to_delete
[
i
].
name
)
# elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']):
# delete_node.append(nodes_to_delete[i].name)
return
delete_size
,
delete_node
def
_get_delete_node_size
(
self
,
user
,
user_to_last_uses
):
return
self
.
_get_delete_node
(
user
,
user_to_last_uses
)[
0
]
def
_get_delete_node_size
(
self
,
user
,
user_to_last_uses
,
to_keep
):
return
self
.
_get_delete_node
(
user
,
user_to_last_uses
,
to_keep
)[
0
]
def
_remove_deactive_node
(
self
,
user
,
user_to_last_uses
,
active_list
):
delete_node
=
self
.
_get_delete_node
(
user
,
user_to_last_uses
)[
1
]
for
i
in
delete_node
:
if
i
in
active_list
:
active_list
.
remove
(
i
)
def
_get_chunk_inputs_size
(
self
,
chunk_inputs
,
chunk_inputs_non_chunk
,
node_list
,
chunk_end_idx
):
nodes_to_delete
=
[]
for
chunk_input
in
chunk_inputs
+
chunk_inputs_non_chunk
:
chunk_input_users
=
chunk_input
.
users
.
keys
()
chunk_input_users_idx
=
[
_find_idx_by_name
(
i
.
name
,
node_list
)
for
i
in
chunk_input_users
]
if
all
(
i
<=
chunk_end_idx
for
i
in
chunk_input_users_idx
):
if
chunk_input
not
in
nodes_to_delete
:
nodes_to_delete
.
append
(
chunk_input
)
out_node
=
[
self
.
_get_output_node
(
i
)
for
i
in
nodes_to_delete
]
delete_size
=
sum
([
i
[
0
]
for
i
in
out_node
])
return
delete_size
def
_get_last_usr
(
self
,
nodes
):
node_to_last_use
:
Dict
[
Node
,
Node
]
=
{}
user_to_last_uses
:
Dict
[
Node
,
List
[
Node
]]
=
{}
...
...
@@ -1000,7 +1024,8 @@ class MemoryEstimator(object):
def
_get_contiguous_memory
(
self
,
node
,
not_contiguous_list
,
delete
=
False
):
mem
=
0
not_contiguous_ops
=
[
"transpose"
,
"permute"
]
not_contiguous_ops
=
[
"permute"
]
inherit_contiguous_ops
=
[
"transpose"
,
"view"
]
if
node
.
op
==
"call_function"
and
any
(
n
in
node
.
name
for
n
in
[
"matmul"
,
"reshape"
]
...
...
@@ -1020,29 +1045,35 @@ class MemoryEstimator(object):
):
if
node
not
in
not_contiguous_list
:
not_contiguous_list
.
append
(
node
)
elif
any
(
i
in
node
.
args
for
i
in
not_contiguous_list
):
if
node
not
in
not_contiguous_list
:
not_contiguous_list
.
append
(
node
)
return
mem
def _get_chunk_ratio(self, node, chunk_inputs, chunk_inputs_dim, chunk_size):
    """Ratio of ``chunk_size`` to the size of the chunked dimension of ``node``.

    For every chunk input and every ``{node index: dim}`` entry of its
    dimension map, the dimension is followed through the index tracer.
    If the entry refers to ``node`` itself, the inherited dimension is used
    directly; otherwise the source trace of ``node`` is searched for a
    dimension that originates from that entry.

    Args:
        node: node whose chunk ratio is wanted.
        chunk_inputs: input nodes of the chunk region.
        chunk_inputs_dim: one ``{node index: dim}`` mapping per chunk input,
            paired positionally with ``chunk_inputs``.
        chunk_size: chunk length along the chunked dimension.

    Returns:
        ``float(chunk_size) / node_shape[dim]`` for the first matching
        dimension, or ``1.`` when no dimension of ``node`` matches.
    """
    # Only this definition is kept: a same-named earlier variant would be
    # shadowed by it and never returned its computed ratio.
    node_shape = _get_node_shape(node)
    node_source = self.index_tracer._find_source_trace_from_node(node)
    for input_node, input_node_dim in zip(chunk_inputs, chunk_inputs_dim):
        for k, v in input_node_dim.items():
            inherit_dim = self.index_tracer._find_inherit_dim(
                input_node, v, self.index_tracer.nodes_list[k]
            )
            # The entry points at ``node`` itself: use its dimension directly.
            if k == _find_idx_by_name(node.name, self.index_tracer.nodes_list):
                return float(chunk_size) / node_shape[inherit_dim]
            # Otherwise look for a dimension of ``node`` traced from node ``k``.
            for dim, source in enumerate(node_source):
                if k in source and source[k] == inherit_dim:
                    return float(chunk_size) / node_shape[dim]
    return 1.
def
_get_chunk_delete_node_size
(
self
,
user
,
user_to_last_uses
,
chunk_ratio
,
node_list
,
start_node
,
end_node
self
,
user
,
user_to_last_uses
,
chunk_ratio
,
chunk_inputs_names
):
# if any(j in user.name for j in ['transpose', 'permute', 'view']):
# return 0
if
user
.
op
in
(
"placeholder"
,
"output"
):
return
0
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
delete_size
=
0
for
n
in
nodes_to_delete
:
node_idx
=
_find_idx_by_name
(
n
.
name
,
node_list
)
if
start_node
<=
node_idx
<
end_node
:
if
n
.
name
in
chunk_inputs_names
:
continue
delete_size
+=
self
.
_get_output_node_size
(
n
)
*
chunk_ratio
return
delete_size
...
...
@@ -1071,10 +1102,7 @@ class MemoryEstimator(object):
def
estimate_chunk_inference_mem
(
self
,
gm
:
torch
.
fx
.
GraphModule
,
start_nodes
=
None
,
end_nodes
=
None
,
chunk_dims
=
None
,
chunk_sizes
=
None
,
chunk_infos
=
None
,
):
act_memory
=
0.0
act_memory_peak_log
=
[]
...
...
@@ -1087,36 +1115,53 @@ class MemoryEstimator(object):
user_to_last_uses_no_free_var
=
self
.
_get_last_usr
(
node_list
)
_delete_free_var_from_last_use
(
user_to_last_uses_no_free_var
)
use_chunk
=
all
(
i
is
not
None
for
i
in
[
start_nodes
,
end_nodes
,
chunk_dims
,
chunk_sizes
]
)
use_chunk
=
True
if
chunk_infos
is
not
None
else
False
chunk_within
=
False
chunk_region_idx
=
None
chunk_ratio
=
1
# use it to estimate chunk mem
chunk_size
=
1
chunk_inputs_names
=
[]
if
use_chunk
:
chunk_regions
=
[
i
[
"region"
]
for
i
in
chunk_infos
]
chunk_starts
=
[
i
[
0
]
for
i
in
chunk_regions
]
chunk_ends
=
[
i
[
1
]
for
i
in
chunk_regions
]
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_infos
]
chunk_inputs_non_chunk
=
[
i
[
"inputs_non_chunk"
]
for
i
in
chunk_infos
]
chunk_inputs_dim
=
[
i
[
"inputs_dim"
]
for
i
in
chunk_infos
]
chunk_inputs_names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
]
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_infos
]
for
idx
,
node
in
enumerate
(
node_list
):
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
if
use_chunk
and
idx
in
start
_node
s
:
if
use_chunk
and
idx
in
chunk_
starts
:
chunk_within
=
True
chunk_region_idx
=
start_nodes
.
index
(
idx
)
chunk_region_idx
=
chunk_starts
.
index
(
idx
)
act_memory
+=
self
.
_get_output_node_size
(
chunk_outputs
[
chunk_region_idx
])
/
(
1024
**
2
)
# determine chunk ratio for current node
if
chunk_within
:
chunk_ratio
=
self
.
_get_chunk_ratio
(
node
,
chunk_
dim
s
[
chunk_region_idx
],
chunk_
sizes
[
chunk_region_idx
]
node
,
chunk_
input
s
[
chunk_region_idx
],
chunk_
inputs_dim
[
chunk_region_idx
]
,
chunk_size
)
act_memory
+=
self
.
_get_output_node_size
(
node_list
[
end_nodes
[
chunk_region_idx
]]
)
/
(
1024
**
2
)
# if node is placeholder, just add the size of the node
if
node
.
op
==
"placeholder"
:
act_memory
+=
self
.
_get_meta_node_size
(
node
)
*
chunk_ratio
/
(
1024
**
2
)
act_memory_peak_log
.
append
(
act_memory
)
active_node_list
.
append
(
node
.
name
)
# skip output
elif
node
.
op
==
"output"
:
continue
# node is an operation, calculate tmp, output node and delete node memory
# no change for non compute node
elif
_is_non_compute_node_except_placeholder
(
node
):
act_memory_peak_log
.
append
(
act_memory
)
# node is a compute op
# calculate tmp, output node and delete node memory
else
:
# forward memory
# TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
act_memory
+=
(
self
.
_get_contiguous_memory
(
node
,
not_contiguous_list
)
*
chunk_ratio
...
...
@@ -1133,29 +1178,35 @@ class MemoryEstimator(object):
*
chunk_ratio
/
(
1024
**
2
)
)
# delete unused vars not in chunk_input_list
# we can't delete input nodes until chunk ends
if
chunk_within
:
act_memory
-=
self
.
_get_chunk_delete_node_size
(
node
,
user_to_last_uses_no_free_var
,
chunk_ratio
,
node_list
,
start_nodes
[
chunk_region_idx
],
end_nodes
[
chunk_region_idx
],
chunk_inputs_names
)
/
(
1024
**
2
)
else
:
act_memory
-=
self
.
_get_delete_node_size
(
node
,
user_to_last_uses_no_free_var
)
/
(
1024
**
2
)
act_memory
-=
(
self
.
_get_delete_node_size
(
node
,
user_to_last_uses_no_free_var
,
chunk_inputs_names
)
/
(
1024
**
2
)
)
# log active node
# log active node
, only effective without chunk
self
.
_add_active_node
(
node
,
active_node_list
)
self
.
_remove_deactive_node
(
node
,
user_to_last_uses
,
active_node_list
)
# if node in chunk end nodes, restore chunk settings
if
use_chunk
and
idx
in
end_node
s
:
if
use_chunk
and
idx
in
chunk_end
s
:
act_memory
-=
(
self
.
_get_output_node_size
(
node
)
*
chunk_ratio
/
(
1024
**
2
)
)
act_memory
-=
self
.
_get_chunk_inputs_size
(
chunk_inputs
[
chunk_region_idx
],
chunk_inputs_non_chunk
[
chunk_region_idx
],
node_list
,
chunk_regions
[
chunk_region_idx
][
1
]
)
/
(
1024
**
2
)
chunk_within
=
False
chunk_ratio
=
1
chunk_region_idx
=
None
...
...
@@ -1178,11 +1229,11 @@ class ChunkRegionSearch(object):
def __init__(self, gm) -> None:
    """Build the tracers and the memory estimator for a graph module.

    Args:
        gm: graph module; its ``graph.nodes`` are copied into
            ``self.node_list`` and it is handed to ``IndexTracer`` and
            ``FlowTracer``.
    """
    self.gm = gm
    self.node_list = list(gm.graph.nodes)
    self.index_tracer = IndexTracer(gm)
    self.index_tracer.trace_index()
    self.flow_tracer = FlowTracer(gm)
    self.flow_tracer.trace_flow()
    # The estimator is created once, after the index tracer exists, because
    # it is constructed from that tracer.
    self.memory_estimator = MemoryEstimator(self.index_tracer)
def
_find_peak_node
(
self
,
mem_peak
):
max_value
=
max
(
mem_peak
)
...
...
@@ -1210,7 +1261,7 @@ class ChunkRegionSearch(object):
min_var
=
self
.
_get_min_free_var
(
active_node
,
free_vars
)
# from peak_node to free_var
chunk_region_start
=
None
chunk_region_start
=
len
(
free_vars
)
for
i
in
range
(
peak_node
,
-
1
,
-
1
):
if
len
(
active_node
[
i
])
==
min_var
:
chunk_region_start
=
i
+
1
...
...
@@ -1218,7 +1269,7 @@ class ChunkRegionSearch(object):
if
i
in
free_vars
or
i
==
0
:
raise
RuntimeError
()
# from peak_node to len-2
chunk_region_end
=
None
chunk_region_end
=
len
(
active_node
)
-
1
for
i
in
range
(
peak_node
,
len
(
active_node
)):
if
len
(
active_node
[
i
])
==
min_var
:
chunk_region_end
=
i
...
...
@@ -1352,7 +1403,7 @@ class ChunkRegionSearch(object):
return
False
def
search_region
(
self
):
chunk_
region
s
=
[]
chunk_
info
s
=
[]
(
init_mem_peak
,
_
,
...
...
@@ -1361,25 +1412,19 @@ class ChunkRegionSearch(object):
mem_peak
=
init_mem_peak
while
True
:
chunk_
region
=
self
.
_step_search
(
mem_peak
,
active_node
,
chunk_
region
s
)
if
chunk_
region
is
None
:
chunk_
info
=
self
.
_step_search
(
mem_peak
,
active_node
,
chunk_
info
s
)
if
chunk_
info
is
None
:
break
chunk_
region
s
.
append
(
chunk_
region
)
chunk_
info
s
.
append
(
chunk_
info
)
(
mem_peak
,
_
,
active_node
,
)
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
gm
,
[
i
[
"region"
][
0
]
for
i
in
chunk_regions
],
[
i
[
"region"
][
1
]
for
i
in
chunk_regions
],
[
i
[
"inputs_dim"
]
for
i
in
chunk_regions
],
[
1
]
*
len
(
chunk_regions
),
)
)
=
self
.
memory_estimator
.
estimate_chunk_inference_mem
(
self
.
gm
,
chunk_infos
)
if
self
.
_stop_search
(
init_mem_peak
,
mem_peak
):
break
return
chunk_
region
s
return
chunk_
info
s
def
_gen_chunk_slice_dim
(
chunk_dim
,
chunk_idx_name
,
shape
):
...
...
@@ -1415,7 +1460,7 @@ def _gen_loop_end(
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_outputs_dim
,
"chunk_idx"
,
chunk_output_shape
)
context
=
" chunk_result%s = %s
\n
"
%
(
chunk_slice
,
chunk_outputs_name
)
context
=
" chunk_result%s = %s
; %s = None
\n
"
%
(
chunk_slice
,
chunk_outputs_name
,
chunk_outputs_name
)
context
+=
(
chunk_outputs_name
+
" = chunk_result; chunk_result = None; chunk_size = None"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment