Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
ded10056
Commit
ded10056
authored
Dec 21, 2022
by
oahzxl
Browse files
format code
parent
d361d533
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
122 additions
and
62 deletions
+122
-62
chunk_codegen.py
chunk_codegen.py
+122
-62
No files found.
chunk_codegen.py
View file @
ded10056
...
@@ -144,7 +144,9 @@ class IndexTracer(object):
...
@@ -144,7 +144,9 @@ class IndexTracer(object):
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
]
=
[
node_from_dim
]
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
]
=
[
node_from_dim
]
else
:
else
:
if
node_from_dim
not
in
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
]:
if
node_from_dim
not
in
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
]:
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
].
append
(
node_from_dim
)
node_to_trace
[
"source"
][
node_to_dim
][
node_from_idx
].
append
(
node_from_dim
)
# update inputs source
# update inputs source
node_to_trace
[
"source"
][
node_to_dim
].
update
(
node_to_trace
[
"source"
][
node_to_dim
].
update
(
node_from_trace
[
"source"
][
node_from_dim
]
node_from_trace
[
"source"
][
node_from_dim
]
...
@@ -745,7 +747,6 @@ class IndexTracer(object):
...
@@ -745,7 +747,6 @@ class IndexTracer(object):
return
True
return
True
class
FlowTracer
(
object
):
class
FlowTracer
(
object
):
def
__init__
(
self
,
gm
)
->
None
:
def
__init__
(
self
,
gm
)
->
None
:
self
.
gm
=
gm
self
.
gm
=
gm
...
@@ -856,7 +857,9 @@ class FlowTracer(object):
...
@@ -856,7 +857,9 @@ class FlowTracer(object):
)
)
return
self
.
flow_trace
return
self
.
flow_trace
def
_detect_flow
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
,
index_tracer
:
IndexTracer
):
def
_detect_flow
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
,
index_tracer
:
IndexTracer
):
inputs
,
outputs
=
_find_chunk_compute_input_and_output_nodes
(
inputs
,
outputs
=
_find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
)
...
@@ -946,7 +949,9 @@ class FlowTracer(object):
...
@@ -946,7 +949,9 @@ class FlowTracer(object):
if
i
in
chunk_info
[
"inputs"
]:
if
i
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs"
].
remove
(
i
)
chunk_info
[
"inputs"
].
remove
(
i
)
duplicate_result
,
duplicate_dim
=
index_tracer
.
check_index_duplicate
(
chunk_info
,
return_dim
=
True
)
duplicate_result
,
duplicate_dim
=
index_tracer
.
check_index_duplicate
(
chunk_info
,
return_dim
=
True
)
# we need to log input nodes to avoid deleteing them in the loop
# we need to log input nodes to avoid deleteing them in the loop
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
...
@@ -958,10 +963,20 @@ class FlowTracer(object):
...
@@ -958,10 +963,20 @@ class FlowTracer(object):
return
flow_block
,
chunk_info
return
flow_block
,
chunk_info
def
_assgin_single_node_flow
(
self
,
arg_node
,
start_idx
,
end_idx
,
def
_assgin_single_node_flow
(
inputs
,
index_tracer
,
cur_node_dim
,
self
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
arg_node
,
next_node_list
):
start_idx
,
end_idx
,
inputs
,
index_tracer
,
cur_node_dim
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
next_node_list
,
):
arg_idx
=
_find_idx_by_name
(
arg_node
.
name
,
index_tracer
.
nodes_list
)
arg_idx
=
_find_idx_by_name
(
arg_node
.
name
,
index_tracer
.
nodes_list
)
# arg in chunk range or be inputs
# arg in chunk range or be inputs
if
not
(
start_idx
<=
arg_idx
<
end_idx
):
if
not
(
start_idx
<=
arg_idx
<
end_idx
):
...
@@ -991,15 +1006,19 @@ class FlowTracer(object):
...
@@ -991,15 +1006,19 @@ class FlowTracer(object):
if
arg_node
in
all_node_info
:
if
arg_node
in
all_node_info
:
if
all_node_info
[
arg_node
]
!=
arg_dim
:
if
all_node_info
[
arg_node
]
!=
arg_dim
:
return
False
return
False
all_node_info
[
arg_node
][
'fix_dim'
]
=
list
(
set
(
all_node_info
[
arg_node
][
'fix_dim'
]
+
arg_fix_dim
))
all_node_info
[
arg_node
][
"fix_dim"
]
=
list
(
set
(
all_node_info
[
arg_node
][
"fix_dim"
]
+
arg_fix_dim
)
)
# else add it to list
# else add it to list
else
:
else
:
all_node_info
[
arg_node
]
=
{
'
chunk_dim
'
:
arg_dim
,
'
fix_dim
'
:
arg_fix_dim
}
all_node_info
[
arg_node
]
=
{
"
chunk_dim
"
:
arg_dim
,
"
fix_dim
"
:
arg_fix_dim
}
next_node_list
.
append
(
arg_node
)
next_node_list
.
append
(
arg_node
)
return
True
return
True
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
,
index_tracer
:
IndexTracer
):
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
,
index_tracer
:
IndexTracer
):
inputs
,
outputs
=
_find_chunk_compute_input_and_output_nodes
(
inputs
,
outputs
=
_find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
)
...
@@ -1008,19 +1027,23 @@ class FlowTracer(object):
...
@@ -1008,19 +1027,23 @@ class FlowTracer(object):
return
None
return
None
cur_node_list
=
[
index_tracer
.
nodes_list
[
end_idx
]]
# start from the last node
cur_node_list
=
[
index_tracer
.
nodes_list
[
end_idx
]]
# start from the last node
all_node_info
=
{
cur_node_list
[
0
]:
{
'
chunk_dim
'
:
end_dim
,
'
fix_dim
'
:
[]}}
all_node_info
=
{
cur_node_list
[
0
]:
{
"
chunk_dim
"
:
end_dim
,
"
fix_dim
"
:
[]}}
while
len
(
cur_node_list
)
>
0
:
while
len
(
cur_node_list
)
>
0
:
next_node_list
=
[]
next_node_list
=
[]
for
cur_node
in
cur_node_list
:
for
cur_node
in
cur_node_list
:
# get cur node info
# get cur node info
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
'
chunk_dim
'
]
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
"
chunk_dim
"
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
'
fix_dim
'
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
"
fix_dim
"
]
cur_node_idx
=
_find_idx_by_name
(
cur_node
.
name
,
index_tracer
.
nodes_list
)
cur_node_idx
=
_find_idx_by_name
(
cur_node
.
name
,
index_tracer
.
nodes_list
)
if
cur_node_chunk_dim
:
if
cur_node_chunk_dim
:
cur_node_compute
=
index_tracer
.
_find_compute_trace_from_node
(
cur_node
)
cur_node_compute
=
index_tracer
.
_find_compute_trace_from_node
(
cur_node_source
=
index_tracer
.
_find_source_trace_from_node
(
cur_node
)
cur_node
)
cur_node_source
=
index_tracer
.
_find_source_trace_from_node
(
cur_node
)
else
:
else
:
cur_node_compute
=
cur_node_source
=
None
cur_node_compute
=
cur_node_source
=
None
...
@@ -1032,20 +1055,33 @@ class FlowTracer(object):
...
@@ -1032,20 +1055,33 @@ class FlowTracer(object):
if
_is_non_compute_node
(
arg
):
if
_is_non_compute_node
(
arg
):
continue
continue
arg_list
.
append
(
arg
)
arg_list
.
append
(
arg
)
flow_flag
=
self
.
_assgin_single_node_flow
(
arg
,
start_idx
,
end_idx
,
flow_flag
=
self
.
_assgin_single_node_flow
(
inputs
,
index_tracer
,
cur_node_chunk_dim
,
arg
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
start_idx
,
next_node_list
)
end_idx
,
inputs
,
index_tracer
,
cur_node_chunk_dim
,
cur_node_compute
,
cur_node_source
,
cur_node_fix_dim
,
all_node_info
,
next_node_list
,
)
if
flow_flag
==
False
:
if
flow_flag
==
False
:
return
None
return
None
if
len
(
arg_list
)
==
2
:
if
len
(
arg_list
)
==
2
:
if
any
(
i
in
cur_node
.
name
for
i
in
[
"add"
,
"mul"
]):
if
any
(
i
in
cur_node
.
name
for
i
in
[
"add"
,
"mul"
]):
for
arg
in
arg_list
:
for
arg
in
arg_list
:
if
not
(
start_idx
<=
_find_idx_by_name
(
arg
.
name
,
index_tracer
.
nodes_list
)
<
end_idx
):
if
not
(
start_idx
<=
_find_idx_by_name
(
arg
.
name
,
index_tracer
.
nodes_list
)
<
end_idx
):
continue
continue
arg_chunk_dim
=
all_node_info
[
arg
][
'
chunk_dim
'
]
arg_chunk_dim
=
all_node_info
[
arg
][
"
chunk_dim
"
]
arg_fix_dim
=
all_node_info
[
arg
][
'
fix_dim
'
]
arg_fix_dim
=
all_node_info
[
arg
][
"
fix_dim
"
]
arg_shape
=
_get_node_shape
(
arg
)
arg_shape
=
_get_node_shape
(
arg
)
# add all dim as fix dim except chunk dim
# add all dim as fix dim except chunk dim
for
i
,
shape
in
enumerate
(
arg_shape
):
for
i
,
shape
in
enumerate
(
arg_shape
):
...
@@ -1071,7 +1107,7 @@ class FlowTracer(object):
...
@@ -1071,7 +1107,7 @@ class FlowTracer(object):
continue
continue
user_idx
=
_find_idx_by_name
(
user
.
name
,
self
.
node_list
)
user_idx
=
_find_idx_by_name
(
user
.
name
,
self
.
node_list
)
if
start_idx
<=
user_idx
<=
end_idx
:
if
start_idx
<=
user_idx
<=
end_idx
:
chunk_dim
=
all_node_info
[
user
][
'
chunk_dim
'
]
chunk_dim
=
all_node_info
[
user
][
"
chunk_dim
"
]
if
chunk_dim
is
not
None
:
if
chunk_dim
is
not
None
:
input_dict
[
user_idx
]
=
chunk_dim
input_dict
[
user_idx
]
=
chunk_dim
if
len
(
input_dict
)
==
0
:
if
len
(
input_dict
)
==
0
:
...
@@ -1129,7 +1165,7 @@ class MemoryEstimator(object):
...
@@ -1129,7 +1165,7 @@ class MemoryEstimator(object):
def
_add_active_node
(
self
,
n
,
active_list
):
def
_add_active_node
(
self
,
n
,
active_list
):
new_active
=
self
.
_get_output_node
(
n
)[
1
]
new_active
=
self
.
_get_output_node
(
n
)[
1
]
if
n
.
op
==
'
placeholder
'
:
if
n
.
op
==
"
placeholder
"
:
new_active
.
append
(
n
.
name
)
new_active
.
append
(
n
.
name
)
for
i
in
new_active
:
for
i
in
new_active
:
if
i
not
in
active_list
:
if
i
not
in
active_list
:
...
@@ -1169,11 +1205,15 @@ class MemoryEstimator(object):
...
@@ -1169,11 +1205,15 @@ class MemoryEstimator(object):
if
i
in
active_list
:
if
i
in
active_list
:
active_list
.
remove
(
i
)
active_list
.
remove
(
i
)
def
_get_chunk_inputs_size
(
self
,
chunk_inputs
,
chunk_inputs_non_chunk
,
node_list
,
chunk_end_idx
):
def
_get_chunk_inputs_size
(
self
,
chunk_inputs
,
chunk_inputs_non_chunk
,
node_list
,
chunk_end_idx
):
nodes_to_delete
=
[]
nodes_to_delete
=
[]
for
chunk_input
in
chunk_inputs
+
chunk_inputs_non_chunk
:
for
chunk_input
in
chunk_inputs
+
chunk_inputs_non_chunk
:
chunk_input_users
=
chunk_input
.
users
.
keys
()
chunk_input_users
=
chunk_input
.
users
.
keys
()
chunk_input_users_idx
=
[
_find_idx_by_name
(
i
.
name
,
node_list
)
for
i
in
chunk_input_users
]
chunk_input_users_idx
=
[
_find_idx_by_name
(
i
.
name
,
node_list
)
for
i
in
chunk_input_users
]
if
all
(
i
<=
chunk_end_idx
for
i
in
chunk_input_users_idx
):
if
all
(
i
<=
chunk_end_idx
for
i
in
chunk_input_users_idx
):
if
chunk_input
not
in
nodes_to_delete
:
if
chunk_input
not
in
nodes_to_delete
:
nodes_to_delete
.
append
(
chunk_input
)
nodes_to_delete
.
append
(
chunk_input
)
...
@@ -1226,7 +1266,9 @@ class MemoryEstimator(object):
...
@@ -1226,7 +1266,9 @@ class MemoryEstimator(object):
for
(
input_node
,
input_node_dim
)
in
zip
(
chunk_inputs
,
chunk_inputs_dim
):
for
(
input_node
,
input_node_dim
)
in
zip
(
chunk_inputs
,
chunk_inputs_dim
):
for
k
,
v
in
input_node_dim
.
items
():
for
k
,
v
in
input_node_dim
.
items
():
# TODO: inherit dim should be list too, int now
# TODO: inherit dim should be list too, int now
inherit_dim
=
self
.
index_tracer
.
_find_inherit_dim
(
input_node
,
v
,
self
.
index_tracer
.
nodes_list
[
k
])
inherit_dim
=
self
.
index_tracer
.
_find_inherit_dim
(
input_node
,
v
,
self
.
index_tracer
.
nodes_list
[
k
]
)
if
k
==
_find_idx_by_name
(
node
.
name
,
self
.
index_tracer
.
nodes_list
):
if
k
==
_find_idx_by_name
(
node
.
name
,
self
.
index_tracer
.
nodes_list
):
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
inherit_dim
]
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
inherit_dim
]
return
chunk_ratio
return
chunk_ratio
...
@@ -1234,7 +1276,7 @@ class MemoryEstimator(object):
...
@@ -1234,7 +1276,7 @@ class MemoryEstimator(object):
if
k
in
source
and
inherit_dim
in
source
[
k
]:
if
k
in
source
and
inherit_dim
in
source
[
k
]:
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
dim
]
chunk_ratio
=
float
(
chunk_size
)
/
node_shape
[
dim
]
return
chunk_ratio
return
chunk_ratio
return
1.
return
1.
0
def
_get_chunk_delete_node_size
(
def
_get_chunk_delete_node_size
(
self
,
user
,
user_to_last_uses
,
chunk_ratio
,
chunk_inputs_names
self
,
user
,
user_to_last_uses
,
chunk_ratio
,
chunk_inputs_names
...
@@ -1313,12 +1355,17 @@ class MemoryEstimator(object):
...
@@ -1313,12 +1355,17 @@ class MemoryEstimator(object):
if
use_chunk
and
idx
in
chunk_starts
:
if
use_chunk
and
idx
in
chunk_starts
:
chunk_within
=
True
chunk_within
=
True
chunk_region_idx
=
chunk_starts
.
index
(
idx
)
chunk_region_idx
=
chunk_starts
.
index
(
idx
)
act_memory
+=
self
.
_get_output_node_size
(
chunk_outputs
[
chunk_region_idx
])
/
(
1024
**
2
)
act_memory
+=
self
.
_get_output_node_size
(
chunk_outputs
[
chunk_region_idx
]
)
/
(
1024
**
2
)
# determine chunk ratio for current node
# determine chunk ratio for current node
if
chunk_within
:
if
chunk_within
:
chunk_ratio
=
self
.
_get_chunk_ratio
(
chunk_ratio
=
self
.
_get_chunk_ratio
(
node
,
chunk_inputs
[
chunk_region_idx
],
chunk_inputs_dim
[
chunk_region_idx
],
chunk_size
node
,
chunk_inputs
[
chunk_region_idx
],
chunk_inputs_dim
[
chunk_region_idx
],
chunk_size
,
)
)
# if node is placeholder, just add the size of the node
# if node is placeholder, just add the size of the node
...
@@ -1359,12 +1406,12 @@ class MemoryEstimator(object):
...
@@ -1359,12 +1406,12 @@ class MemoryEstimator(object):
node
,
node
,
user_to_last_uses_no_free_var
,
user_to_last_uses_no_free_var
,
chunk_ratio
,
chunk_ratio
,
chunk_inputs_names
chunk_inputs_names
,
)
/
(
1024
**
2
)
)
/
(
1024
**
2
)
else
:
else
:
act_memory
-=
(
self
.
_get_delete_node_size
(
act_memory
-=
self
.
_get_delete_node_size
(
node
,
user_to_last_uses_no_free_var
,
chunk_inputs_names
node
,
user_to_last_uses_no_free_var
,
chunk_inputs_names
)
/
(
1024
**
2
)
)
)
/
(
1024
**
2
)
# log active node, only effective without chunk
# log active node, only effective without chunk
self
.
_add_active_node
(
node
,
active_node_list
)
self
.
_add_active_node
(
node
,
active_node_list
)
...
@@ -1379,7 +1426,7 @@ class MemoryEstimator(object):
...
@@ -1379,7 +1426,7 @@ class MemoryEstimator(object):
chunk_inputs
[
chunk_region_idx
],
chunk_inputs
[
chunk_region_idx
],
chunk_inputs_non_chunk
[
chunk_region_idx
],
chunk_inputs_non_chunk
[
chunk_region_idx
],
node_list
,
node_list
,
chunk_regions
[
chunk_region_idx
][
1
]
chunk_regions
[
chunk_region_idx
][
1
]
,
)
/
(
1024
**
2
)
)
/
(
1024
**
2
)
chunk_within
=
False
chunk_within
=
False
chunk_ratio
=
1
chunk_ratio
=
1
...
@@ -1494,7 +1541,12 @@ class ChunkRegionSearch(object):
...
@@ -1494,7 +1541,12 @@ class ChunkRegionSearch(object):
continue
continue
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_dim
,
start_trace_idx
in
enumerate
(
start_trace
[
"idx"
]):
for
start_dim
,
start_trace_idx
in
enumerate
(
start_trace
[
"idx"
]):
if
start_idx
==
199
and
end_idx
==
229
and
start_dim
==
2
and
end_dim
==
2
:
if
(
start_idx
==
199
and
end_idx
==
229
and
start_dim
==
2
and
end_dim
==
2
):
print
(
1
)
print
(
1
)
self
.
flow_tracer
.
flow_search
(
self
.
flow_tracer
.
flow_search
(
start_idx
,
start_dim
,
end_idx
,
end_dim
,
self
.
index_tracer
start_idx
,
start_dim
,
end_idx
,
end_dim
,
self
.
index_tracer
...
@@ -1585,8 +1637,10 @@ class ChunkRegionSearch(object):
...
@@ -1585,8 +1637,10 @@ class ChunkRegionSearch(object):
return
False
return
False
for
i
in
chunk_infos
:
for
i
in
chunk_infos
:
region
=
i
[
"region"
]
region
=
i
[
"region"
]
if
not
((
chunk_region_start
>
region
[
1
]
and
chunk_region_end
>
region
[
1
])
if
not
(
or
(
chunk_region_start
<
region
[
0
]
and
chunk_region_end
<
region
[
0
])):
(
chunk_region_start
>
region
[
1
]
and
chunk_region_end
>
region
[
1
])
or
(
chunk_region_start
<
region
[
0
]
and
chunk_region_end
<
region
[
0
])
):
return
False
return
False
return
True
return
True
...
@@ -1600,7 +1654,9 @@ class ChunkRegionSearch(object):
...
@@ -1600,7 +1654,9 @@ class ChunkRegionSearch(object):
possible_chunk_regions
=
self
.
_search_possible_chunk_regions
(
possible_chunk_regions
=
self
.
_search_possible_chunk_regions
(
max_chunk_region
,
peak_node
max_chunk_region
,
peak_node
)
)
best_chunk_region
=
self
.
_search_best_chunk_region
(
possible_chunk_regions
,
chunk_regions
)
best_chunk_region
=
self
.
_search_best_chunk_region
(
possible_chunk_regions
,
chunk_regions
)
return
best_chunk_region
return
best_chunk_region
def
_stop_search
(
self
,
init_mem_peak
,
mem_peak
):
def
_stop_search
(
self
,
init_mem_peak
,
mem_peak
):
...
@@ -1667,7 +1723,11 @@ def _gen_loop_end(
...
@@ -1667,7 +1723,11 @@ def _gen_loop_end(
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_outputs_dim
,
"chunk_idx"
,
chunk_output_shape
chunk_outputs_dim
,
"chunk_idx"
,
chunk_output_shape
)
)
context
=
" chunk_result%s = %s; %s = None
\n
"
%
(
chunk_slice
,
chunk_outputs_name
,
chunk_outputs_name
)
context
=
" chunk_result%s = %s; %s = None
\n
"
%
(
chunk_slice
,
chunk_outputs_name
,
chunk_outputs_name
,
)
context
+=
(
context
+=
(
chunk_outputs_name
+
" = chunk_result; chunk_result = None; chunk_size = None"
chunk_outputs_name
+
" = chunk_result; chunk_result = None; chunk_size = None"
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment