Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
2b4ebcc2
"tests/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "0160a62a3c8f2b4bef5edd9997037fba69bf0da7"
Commit
2b4ebcc2
authored
Dec 08, 2022
by
oahzxl
Browse files
finish codegen on msa
parent
6d99994a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
188 additions
and
24 deletions
+188
-24
chunk_codegen.py
chunk_codegen.py
+188
-24
No files found.
chunk_codegen.py
View file @
2b4ebcc2
...
@@ -17,6 +17,121 @@ def _delete_free_var_from_last_use(user_to_last_uses):
...
@@ -17,6 +17,121 @@ def _delete_free_var_from_last_use(user_to_last_uses):
user_to_last_uses
[
key
].
remove
(
n
)
user_to_last_uses
[
key
].
remove
(
n
)
class FlowTracer(object):
    """Trace how data "flows" through a torch.fx graph.

    Each placeholder starts its own flow; every compute node is attached to
    the first valid flow among its arguments (its "dominant" flow).  For each
    node we record which arguments come from the same flow (``inside_depend``)
    and which come from a different flow (``outside_depend``) — the latter
    marks a point where two flows mix.
    """

    def __init__(self, gm) -> None:
        # NOTE(review): gm is assumed to be a torch.fx.GraphModule — confirm at call sites.
        self.gm = gm
        self.nodes_list = list(gm.graph.nodes)
        # flow name -> list of {'node', 'inside_depend', 'outside_depend'} entries
        self.flow_trace = {}

    def _add_trace(self, name):
        """Open a new, empty flow under ``name``."""
        self.flow_trace[name] = []

    def _add_node(self, trace_name, node):
        """Append ``node`` to flow ``trace_name`` with empty dependency lists."""
        self.flow_trace[trace_name].append({'node': node, 'inside_depend': [], 'outside_depend': []})

    def _add_inside_depend(self, flow_name, node, inside_depend_node):
        """Record that ``node`` (in flow ``flow_name``) depends on a node of the same flow.

        Raises:
            RuntimeError: if ``node`` is not registered in ``flow_name``.
        """
        for entry in self.flow_trace[flow_name]:
            if entry['node'] == node:
                entry['inside_depend'].append(inside_depend_node)
                return
        raise RuntimeError("node not found")

    def _add_outside_depend(self, flow_name, node, outside_depend_node, outside_depend_trace):
        """Record that ``node`` (in flow ``flow_name``) depends on a node of another flow.

        The dependency is stored as a single-key dict mapping the foreign flow
        name to the foreign node.

        Raises:
            RuntimeError: if ``node`` is not registered in ``flow_name``.
        """
        for entry in self.flow_trace[flow_name]:
            if entry['node'] == node:
                entry['outside_depend'].append({outside_depend_trace: outside_depend_node})
                return
        raise RuntimeError("node not found")

    def _init_trace(self):
        """Seed one flow per graph placeholder, each containing the placeholder itself."""
        for node in self.nodes_list:
            if node.op == 'placeholder':
                self._add_trace(node.name)
                self._add_node(node.name, node)

    def _is_non_compute_node(self, node):
        """Return True for placeholders, attributes, outputs and getitem/getattr nodes."""
        # Substring checks on op/name mirror how fx names these nodes.
        if any(i in node.op for i in ['placeholder', 'get_attr', 'output']) or \
                any(i in node.name for i in ['getitem', 'getattr']):
            return True
        return False

    def _is_non_compute_node_except_placeholder(self, node):
        """Like :meth:`_is_non_compute_node` but treats placeholders as compute nodes."""
        if any(i in node.op for i in ['get_attr', 'output']) or \
                any(i in node.name for i in ['getitem', 'getattr']):
            return True
        return False

    def _find_flow_for_node(self, node):
        """Return the name of the flow that already contains ``node``.

        Returns ``None`` for non-Node args (constants) and for non-compute
        nodes.  A ``ones_like`` node that is not yet traced starts a fresh
        flow of its own.

        Raises:
            RuntimeError: if ``node`` is a compute node in no known flow.
        """
        if type(self.nodes_list[0]) != type(node):
            return None
        if self._is_non_compute_node_except_placeholder(node):
            return None
        for name, trace in self.flow_trace.items():
            for entry in trace:
                if node == entry['node']:
                    return name
        if any(i in node.name for i in ["ones_like"]):
            self._add_trace(node.name)
            self._add_node(node.name, node)
            return node.name
        raise RuntimeError("node not found")

    def _find_first_valid_flow(self, flow):
        """Return the first non-``None`` flow name in ``flow``.

        Raises:
            RuntimeError: if every element is ``None``.
        """
        for name in flow:
            if name is not None:
                return name
        raise RuntimeError("invalid flow")

    def find_node_flow(self, node):
        """Return ``(flow_name, trace_entry)`` for ``node``.

        Raises:
            RuntimeError: if ``node`` belongs to no flow.
        """
        for name, trace in self.flow_trace.items():
            for entry in trace:
                if node == entry['node']:
                    return name, entry
        raise RuntimeError("invalid node")

    def get_flow_mix(self, node):
        """Return the foreign node mixed into ``node``'s flow, or ``None``.

        Raises:
            NotImplementedError: if ``node`` mixes more than one foreign flow.
        """
        if self._is_non_compute_node(node):
            return None
        _, node_trace = self.find_node_flow(node)
        if len(node_trace['outside_depend']) == 0:
            return None
        elif len(node_trace['outside_depend']) > 1:
            raise NotImplementedError
        # fix: renamed from `vars`, which shadowed the builtin
        outside_var = list(node_trace['outside_depend'][0].values())[0]
        return outside_var

    def get_same_flow_node(self, node_list, node):
        """Return the members of ``node``'s flow that also appear in ``node_list``."""
        name, _ = self.find_node_flow(node)
        result = []
        for entry in self.flow_trace[name]:
            if entry['node'] in node_list:
                result.append(entry['node'])
        return result

    def trace_flow(self):
        """Build and return the full flow trace for the graph.

        Returns:
            dict: flow name -> list of trace entries (see class docstring).
        """
        # init trace
        self._init_trace()

        for node in self.nodes_list:
            # skip if non compute node
            if all(type(arg) != type(node) or self._is_non_compute_node_except_placeholder(arg)
                   for arg in node.args) \
                    or self._is_non_compute_node(node):
                continue

            node_input_flows = [self._find_flow_for_node(arg) for arg in node.args]
            node_domin_flow = self._find_first_valid_flow(node_input_flows)
            self._add_node(node_domin_flow, node)
            for node_input_flow, arg in zip(node_input_flows, node.args):
                if node_input_flow is None:
                    continue
                elif node_input_flow == node_domin_flow:
                    self._add_inside_depend(node_domin_flow, node, arg)
                else:
                    self._add_outside_depend(node_domin_flow, node, arg, node_input_flow)
        return self.flow_trace
class
IndexTracer
(
object
):
class
IndexTracer
(
object
):
def
__init__
(
self
,
gm
)
->
None
:
def
__init__
(
self
,
gm
)
->
None
:
self
.
gm
=
gm
self
.
gm
=
gm
...
@@ -428,7 +543,7 @@ class IndexTracer(object):
...
@@ -428,7 +543,7 @@ class IndexTracer(object):
if
merge_from
in
trace
[
'idx'
]:
if
merge_from
in
trace
[
'idx'
]:
trace
[
'idx'
]
=
[
merge_to
if
i
==
merge_from
else
i
for
i
in
trace
[
'idx'
]]
trace
[
'idx'
]
=
[
merge_to
if
i
==
merge_from
else
i
for
i
in
trace
[
'idx'
]]
def
trace_
node_id
x
(
self
):
def
trace_
inde
x
(
self
):
for
idx
,
node
in
enumerate
(
self
.
nodes_list
):
for
idx
,
node
in
enumerate
(
self
.
nodes_list
):
if
node
.
op
==
'placeholder'
:
if
node
.
op
==
'placeholder'
:
self
.
_assign_all_index
(
node
,
idx
)
self
.
_assign_all_index
(
node
,
idx
)
...
@@ -684,7 +799,9 @@ class ChunkRegionSearch(object):
...
@@ -684,7 +799,9 @@ class ChunkRegionSearch(object):
self
.
node_list
=
list
(
gm
.
graph
.
nodes
)
self
.
node_list
=
list
(
gm
.
graph
.
nodes
)
self
.
memory_estimator
=
MemoryEstimator
()
self
.
memory_estimator
=
MemoryEstimator
()
self
.
index_tracer
=
IndexTracer
(
gm
)
self
.
index_tracer
=
IndexTracer
(
gm
)
self
.
index_tracer
.
trace_node_idx
()
self
.
index_tracer
.
trace_index
()
self
.
flow_tracer
=
FlowTracer
(
gm
)
self
.
flow_tracer
.
trace_flow
()
def
_find_peak_node
(
self
,
mem_peak
):
def
_find_peak_node
(
self
,
mem_peak
):
max_value
=
max
(
mem_peak
)
max_value
=
max
(
mem_peak
)
...
@@ -729,7 +846,7 @@ class ChunkRegionSearch(object):
...
@@ -729,7 +846,7 @@ class ChunkRegionSearch(object):
raise
RuntimeError
()
raise
RuntimeError
()
return
chunk_region_start
,
chunk_region_end
return
chunk_region_start
,
chunk_region_end
def
_not_compute
(
self
,
trace
,
chunk_range
,
dim_idx
):
def
_is
_not_compute
(
self
,
trace
,
chunk_range
,
dim_idx
):
if
trace
[
'idx'
][
dim_idx
]
not
in
trace
[
'compute'
]:
if
trace
[
'idx'
][
dim_idx
]
not
in
trace
[
'compute'
]:
return
True
return
True
if
trace
[
'idx'
][
dim_idx
]
in
trace
[
'compute'
]
and
\
if
trace
[
'idx'
][
dim_idx
]
in
trace
[
'compute'
]
and
\
...
@@ -737,6 +854,56 @@ class ChunkRegionSearch(object):
...
@@ -737,6 +854,56 @@ class ChunkRegionSearch(object):
return
True
return
True
return
False
return
False
def _detect_flow(self, before_trace, after_trace, start_idx, end_idx, dim_idx):
    """Check whether flow mixes inside ``[start_idx, end_idx]`` permit chunking along ``dim_idx``.

    Returns:
        tuple: ``(flow_flag, chunk_info)`` where ``flow_flag`` is True when
        every mix in the region is a broadcast along the chunk dim, and
        ``chunk_info`` holds the region's input/output nodes.

    Raises:
        NotImplementedError: for mixing ops other than mul/add.
    """
    region_nodes = self.node_list[start_idx:end_idx + 1]
    inputs, outputs = _find_input_and_output_nodes(region_nodes)
    chunk_info = {'inputs': inputs, 'outputs': outputs}
    flow_flag = False

    for node in region_nodes:
        mix_flow_var = self.flow_tracer.get_flow_mix(node)
        if mix_flow_var is None:
            continue

        # if there is a flow mix, op must be in [mul, add, div, matmul]
        # element-wise op requires dim to be equal in every dim
        if not any(op_name in node.name for op_name in ('mul', 'add')):
            raise NotImplementedError("%s not implemented" % node.name)

        for candidate in node.args:
            if type(candidate) == type(mix_flow_var) and candidate != mix_flow_var:
                main_flow_var = candidate

        # if mix flow is a broadcast in chunk dim,
        # TODO need to move that flow out of the chunk
        if mix_flow_var.meta['tensor_meta'].shape[dim_idx] == 1:
            flow_flag = True
            for shared_node in self.flow_tracer.get_same_flow_node(chunk_info['inputs'], mix_flow_var):
                chunk_info['inputs'].remove(shared_node)
        # else, we need to chunk mix var as well
        else:
            # TODO chunk another value
            flow_flag = False
            break

    return flow_flag, chunk_info
def
_find_free_dim
(
self
,
input_trace
,
output_trace
,
start_idx
,
end_idx
):
before_trace
=
input_trace
[
start_idx
]
after_trace
=
output_trace
[
end_idx
]
free_dim
=
[]
chunk_infos
=
[]
for
i
in
range
(
min
(
len
(
before_trace
[
'idx'
]),
len
(
after_trace
[
'idx'
]))):
if
not
(
before_trace
[
'idx'
][
i
]
==
after_trace
[
'idx'
][
i
]
and
self
.
_is_not_compute
(
before_trace
,
(
start_idx
,
end_idx
),
i
)
and
self
.
_is_not_compute
(
after_trace
,
(
start_idx
,
end_idx
),
i
)
and
self
.
node_list
[
end_idx
].
meta
[
'tensor_meta'
].
shape
[
i
]
!=
1
):
continue
flow_flag
,
chunk_info
=
self
.
_detect_flow
(
before_trace
,
after_trace
,
start_idx
,
end_idx
,
i
)
if
flow_flag
==
None
:
continue
chunk_infos
.
append
(
chunk_info
)
free_dim
.
append
(
i
)
return
free_dim
,
chunk_infos
def
_search_possible_chunk_regions
(
self
,
max_chunk_region
,
peak_node
):
def
_search_possible_chunk_regions
(
self
,
max_chunk_region
,
peak_node
):
possible_chunk_region
=
[]
possible_chunk_region
=
[]
output_trace
=
copy
.
deepcopy
(
self
.
index_tracer
.
idx_trace_list
)
output_trace
=
copy
.
deepcopy
(
self
.
index_tracer
.
idx_trace_list
)
...
@@ -748,27 +915,22 @@ class ChunkRegionSearch(object):
...
@@ -748,27 +915,22 @@ class ChunkRegionSearch(object):
else
:
else
:
input_trace
.
append
(
None
)
input_trace
.
append
(
None
)
for
before
_idx
in
range
(
max_chunk_region
[
0
],
peak_node
):
for
start
_idx
in
range
(
max_chunk_region
[
0
],
peak_node
):
for
after
_idx
in
range
(
peak_node
,
max_chunk_region
[
1
]
+
1
):
for
end
_idx
in
range
(
peak_node
,
max_chunk_region
[
1
]
+
1
):
# skip non compute nodes
# skip non compute nodes
if
any
(
op
in
[
'placeholder'
,
'get_attr'
,
'output'
]
for
op
in
if
any
(
op
in
[
'placeholder'
,
'get_attr'
,
'output'
]
for
op
in
[
self
.
node_list
[
before
_idx
].
op
,
self
.
node_list
[
after
_idx
].
op
]):
[
self
.
node_list
[
start
_idx
].
op
,
self
.
node_list
[
end
_idx
].
op
]):
continue
continue
if
any
(
any
(
i
in
name
for
i
in
[
'getitem'
,
'getattr'
])
for
name
in
if
any
(
any
(
i
in
name
for
i
in
[
'getitem'
,
'getattr'
])
for
name
in
[
self
.
node_list
[
before
_idx
].
name
,
self
.
node_list
[
after
_idx
].
name
]):
[
self
.
node_list
[
start
_idx
].
name
,
self
.
node_list
[
end
_idx
].
name
]):
continue
continue
# select free dim
# select free dim
before_trace
=
input_trace
[
before_idx
]
free_dim
,
chunk_info
=
self
.
_find_free_dim
(
input_trace
,
output_trace
,
start_idx
,
end_idx
)
after_trace
=
output_trace
[
after_idx
]
if
len
(
free_dim
)
>
0
:
free_dim
=
[]
free_dim
=
[
free_dim
[
0
]]
for
i
in
range
(
min
(
len
(
before_trace
[
'idx'
]),
len
(
after_trace
[
'idx'
]))):
chunk_info
=
[
chunk_info
[
0
]]
if
(
before_trace
[
'idx'
][
i
]
==
after_trace
[
'idx'
][
i
]
and
possible_chunk_region
.
append
({
'region'
:
(
start_idx
,
end_idx
),
'dim'
:
free_dim
,
'chunk_info'
:
chunk_info
})
self
.
_not_compute
(
before_trace
,
(
before_idx
,
after_idx
),
i
)
and
self
.
_not_compute
(
after_trace
,
(
before_idx
,
after_idx
),
i
)
and
self
.
node_list
[
after_idx
].
meta
[
'tensor_meta'
].
shape
[
i
]
!=
1
):
free_dim
.
append
(
i
)
possible_chunk_region
.
append
({
'region'
:
(
before_idx
,
after_idx
),
'dim'
:
free_dim
})
return
possible_chunk_region
return
possible_chunk_region
def
_search_best_chunk_region
(
self
,
possible_chunk_regions
):
def
_search_best_chunk_region
(
self
,
possible_chunk_regions
):
...
@@ -935,21 +1097,23 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
...
@@ -935,21 +1097,23 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
chunk_search
=
chunk_region_search
.
search_region
()
chunk_search
=
chunk_region_search
.
search_region
()
chunk_regions
=
[
i
[
'region'
]
for
i
in
chunk_search
]
chunk_regions
=
[
i
[
'region'
]
for
i
in
chunk_search
]
chunk_dims
=
[
i
[
'dim'
]
for
i
in
chunk_search
]
chunk_dims
=
[
i
[
'dim'
]
for
i
in
chunk_search
]
chunk_infos
=
[
i
[
'chunk_info'
]
for
i
in
chunk_search
]
chunk_starts
=
[
item
[
0
]
for
item
in
chunk_regions
]
chunk_starts
=
[
item
[
0
]
for
item
in
chunk_regions
]
chunk_ends
=
[
item
[
1
]
for
item
in
chunk_regions
]
chunk_ends
=
[
item
[
1
]
for
item
in
chunk_regions
]
chunk_inputs
=
[]
chunk_inputs
=
[
[
j
[
'inputs'
][
0
]
for
j
in
i
]
for
i
in
chunk_infos
]
chunk_outputs
=
[]
chunk_outputs
=
[
[
j
[
'outputs'
][
0
]
for
j
in
i
]
for
i
in
chunk_infos
]
within_chunk_region
=
False
within_chunk_region
=
False
node_list
=
list
(
nodes
)
node_list
=
list
(
nodes
)
# find the input and output var names for each offload region
# find the input and output var names for each offload region
for
idx
,
(
start
,
end
)
in
enumerate
(
chunk_regions
):
# for idx, (start, end) in enumerate(chunk_regions):
offload_node_list
=
node_list
[
start
:
end
+
1
]
# offload_node_list = node_list[start:end + 1]
inputs
,
outputs
=
_find_input_and_output_nodes
(
offload_node_list
)
# inputs, outputs = _find_input_and_output_nodes(offload_node_list)
chunk_inputs
.
append
(
inputs
)
# chunk_inputs.append(inputs)
chunk_outputs
.
append
(
outputs
)
# chunk_outputs.append(outputs)
chunk_inputs_idx
=
[[
_find_idx_by_name
(
j
.
name
,
node_list
)
for
j
in
i
]
for
i
in
chunk_inputs
]
chunk_inputs_idx
=
[[
_find_idx_by_name
(
j
.
name
,
node_list
)
for
j
in
i
]
for
i
in
chunk_inputs
]
chunk_outputs_idx
=
[[
_find_idx_by_name
(
j
.
name
,
node_list
)
for
j
in
i
]
for
i
in
chunk_outputs
]
chunk_outputs_idx
=
[[
_find_idx_by_name
(
j
.
name
,
node_list
)
for
j
in
i
]
for
i
in
chunk_outputs
]
chunk_inputs_names
=
[]
chunk_inputs_names
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment