Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d734529a
Commit
d734529a
authored
Dec 21, 2022
by
oahzxl
Browse files
move flow tracer
parent
9d516fa6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
207 additions
and
206 deletions
+207
-206
chunk_codegen.py
chunk_codegen.py
+207
-206
No files found.
chunk_codegen.py
View file @
d734529a
...
@@ -64,212 +64,6 @@ def _is_non_compute_node_except_placeholder_output(node):
...
@@ -64,212 +64,6 @@ def _is_non_compute_node_except_placeholder_output(node):
return
False
return
False
class
FlowTracer
(
object
):
def
__init__
(
self
,
gm
)
->
None
:
self
.
gm
=
gm
self
.
node_list
=
list
(
gm
.
graph
.
nodes
)
self
.
flow_trace
=
{}
def
_add_trace
(
self
,
name
):
self
.
flow_trace
[
name
]
=
[]
def
_add_node
(
self
,
trace_name
,
node
):
self
.
flow_trace
[
trace_name
].
append
(
{
"node"
:
node
,
"inside_depend"
:
[],
"outside_depend"
:
[]}
)
def
_add_inside_depend
(
self
,
flow_name
,
node
,
inside_depend_node
):
for
i
in
self
.
flow_trace
[
flow_name
]:
if
i
[
"node"
]
==
node
:
i
[
"inside_depend"
].
append
(
inside_depend_node
)
return
raise
RuntimeError
(
"node not found"
)
def
_add_outside_depend
(
self
,
flow_name
,
node
,
outside_depend_node
,
outside_depend_trace
):
for
i
in
self
.
flow_trace
[
flow_name
]:
if
i
[
"node"
]
==
node
:
i
[
"outside_depend"
].
append
({
outside_depend_trace
:
outside_depend_node
})
return
raise
RuntimeError
(
"node not found"
)
def
_init_trace
(
self
):
for
i
in
self
.
node_list
:
if
i
.
op
==
"placeholder"
:
self
.
_add_trace
(
i
.
name
)
self
.
_add_node
(
i
.
name
,
i
)
def
_find_flow_for_node
(
self
,
node
):
if
type
(
self
.
node_list
[
0
])
!=
type
(
node
):
return
None
if
_is_non_compute_node_except_placeholder
(
node
):
return
None
for
name
,
trace
in
self
.
flow_trace
.
items
():
for
i
in
trace
:
if
node
==
i
[
"node"
]:
return
name
if
any
(
i
in
node
.
name
for
i
in
[
"ones_like"
]):
self
.
_add_trace
(
node
.
name
)
self
.
_add_node
(
node
.
name
,
node
)
return
node
.
name
raise
RuntimeError
(
"node not found"
)
def
_find_first_valid_flow
(
self
,
flow
):
for
i
in
flow
:
if
i
is
not
None
:
return
i
raise
RuntimeError
(
"invalid flow"
)
def
find_node_flow
(
self
,
node
):
for
name
,
trace
in
self
.
flow_trace
.
items
():
for
i
in
trace
:
if
node
==
i
[
"node"
]:
return
name
,
i
raise
RuntimeError
(
"invalid node"
)
def
_get_flow_mix_node
(
self
,
node
):
if
_is_non_compute_node
(
node
):
return
None
_
,
node_trace
=
self
.
find_node_flow
(
node
)
if
len
(
node_trace
[
"outside_depend"
])
==
0
:
return
None
elif
len
(
node_trace
[
"outside_depend"
])
>
1
:
raise
NotImplementedError
vars
=
list
(
node_trace
[
"outside_depend"
][
0
].
values
())[
0
]
return
vars
def
_get_same_flow_node
(
self
,
node_list
,
node
):
name
,
_
=
self
.
find_node_flow
(
node
)
result
=
[]
for
i
in
self
.
flow_trace
[
name
]:
if
i
[
"node"
]
in
node_list
:
result
.
append
(
i
[
"node"
])
return
result
def
trace_flow
(
self
):
# init trace
self
.
_init_trace
()
for
node
in
self
.
node_list
:
# skip if non compute node
if
all
(
type
(
arg
)
!=
type
(
node
)
or
_is_non_compute_node_except_placeholder
(
arg
)
for
arg
in
node
.
args
)
or
_is_non_compute_node
(
node
):
continue
node_input_flows
=
[
self
.
_find_flow_for_node
(
arg
)
for
arg
in
node
.
args
]
node_domin_flow
=
self
.
_find_first_valid_flow
(
node_input_flows
)
self
.
_add_node
(
node_domin_flow
,
node
)
for
node_input_flow
,
arg
in
zip
(
node_input_flows
,
node
.
args
):
if
node_input_flow
is
None
:
continue
elif
node_input_flow
==
node_domin_flow
:
self
.
_add_inside_depend
(
node_domin_flow
,
node
,
arg
)
else
:
self
.
_add_outside_depend
(
node_domin_flow
,
node
,
arg
,
node_input_flow
)
return
self
.
flow_trace
def
_detect_flow
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
,
index_tracer
):
inputs
,
outputs
=
_find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
chunk_info
=
{
"region"
:
(
start_idx
,
end_idx
),
"inputs"
:
inputs
,
"inputs_non_chunk"
:
[],
"inputs_dim"
:
start_dim
,
"outputs"
:
outputs
,
"outputs_dim"
:
end_dim
,
"args"
:
{},
}
flow_block
=
False
# TODO don't allow multi outputs now
if
len
(
outputs
)
>
1
:
flow_block
=
True
return
flow_block
,
chunk_info
for
idx
in
range
(
start_idx
,
end_idx
+
1
):
node
=
self
.
node_list
[
idx
]
mix_flow_node
=
self
.
_get_flow_mix_node
(
node
)
if
mix_flow_node
is
None
:
continue
# if there is a flow mix, op must be in [mul, add, matmul]
# element-wise op requires dim to be equal in every dim
if
any
(
n
in
node
.
name
for
n
in
[
"mul"
,
"add"
]):
for
i
in
node
.
args
:
if
type
(
i
)
==
type
(
mix_flow_node
)
and
i
!=
mix_flow_node
:
main_flow_var
=
i
# if mix flow is a broadcast in chunk dim,
# TODO: need to move that flow out of the chunk
mix_flow_node_dim
=
index_tracer
.
get_node_chunk_dim
(
self
.
node_list
[
end_idx
],
end_dim
,
node
)
if
mix_flow_node_dim
is
None
:
flow_block
=
True
break
if
_get_node_shape
(
mix_flow_node
)[
mix_flow_node_dim
]
==
1
:
flow_block
=
False
for
i
in
self
.
_get_same_flow_node
(
chunk_info
[
"inputs"
],
mix_flow_node
):
chunk_info
[
"inputs"
].
remove
(
i
)
# else, we need to chunk mix var as well
else
:
# TODO chunk another value
flow_block
=
True
break
else
:
raise
NotImplementedError
(
"%s not implemented"
%
node
.
name
)
if
flow_block
:
flow_block
=
True
return
flow_block
,
chunk_info
inputs_dim
=
[]
remove_inputs
=
[]
for
input_node
in
chunk_info
[
"inputs"
]:
input_dict
=
{}
for
user
in
input_node
.
users
.
keys
():
if
_is_non_compute_node
(
user
):
continue
user_idx
=
_find_idx_by_name
(
user
.
name
,
self
.
node_list
)
dim
=
None
if
start_dim
<=
user_idx
<
end_idx
:
dim
=
index_tracer
.
get_node_chunk_dim
(
self
.
node_list
[
end_idx
],
end_dim
,
input_node
)
elif
user_idx
==
end_idx
:
dim
=
end_dim
# n has relation with chunk dim
if
dim
is
not
None
and
_get_node_shape
(
user
)[
dim
]
!=
1
:
input_dict
[
user_idx
]
=
dim
if
len
(
input_dict
)
==
0
:
remove_inputs
.
append
(
input_node
)
else
:
inputs_dim
.
append
(
input_dict
)
chunk_info
[
"inputs_dim"
]
=
inputs_dim
for
i
in
remove_inputs
:
if
i
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs"
].
remove
(
i
)
# we need to log input nodes to avoid deleteing them in the loop
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
for
i
in
non_chunk_inputs
:
if
i
not
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
return
flow_block
,
chunk_info
class
IndexTracer
(
object
):
class
IndexTracer
(
object
):
def
__init__
(
self
,
gm
)
->
None
:
def
__init__
(
self
,
gm
)
->
None
:
self
.
gm
=
gm
self
.
gm
=
gm
...
@@ -932,6 +726,213 @@ class IndexTracer(object):
...
@@ -932,6 +726,213 @@ class IndexTracer(object):
return
True
return
True
class
FlowTracer
(
object
):
def
__init__
(
self
,
gm
)
->
None
:
self
.
gm
=
gm
self
.
node_list
=
list
(
gm
.
graph
.
nodes
)
self
.
flow_trace
=
{}
def
_add_trace
(
self
,
name
):
self
.
flow_trace
[
name
]
=
[]
def
_add_node
(
self
,
trace_name
,
node
):
self
.
flow_trace
[
trace_name
].
append
(
{
"node"
:
node
,
"inside_depend"
:
[],
"outside_depend"
:
[]}
)
def
_add_inside_depend
(
self
,
flow_name
,
node
,
inside_depend_node
):
for
i
in
self
.
flow_trace
[
flow_name
]:
if
i
[
"node"
]
==
node
:
i
[
"inside_depend"
].
append
(
inside_depend_node
)
return
raise
RuntimeError
(
"node not found"
)
def
_add_outside_depend
(
self
,
flow_name
,
node
,
outside_depend_node
,
outside_depend_trace
):
for
i
in
self
.
flow_trace
[
flow_name
]:
if
i
[
"node"
]
==
node
:
i
[
"outside_depend"
].
append
({
outside_depend_trace
:
outside_depend_node
})
return
raise
RuntimeError
(
"node not found"
)
def
_init_trace
(
self
):
for
i
in
self
.
node_list
:
if
i
.
op
==
"placeholder"
:
self
.
_add_trace
(
i
.
name
)
self
.
_add_node
(
i
.
name
,
i
)
def
_find_flow_for_node
(
self
,
node
):
if
type
(
self
.
node_list
[
0
])
!=
type
(
node
):
return
None
if
_is_non_compute_node_except_placeholder
(
node
):
return
None
for
name
,
trace
in
self
.
flow_trace
.
items
():
for
i
in
trace
:
if
node
==
i
[
"node"
]:
return
name
if
any
(
i
in
node
.
name
for
i
in
[
"ones_like"
]):
self
.
_add_trace
(
node
.
name
)
self
.
_add_node
(
node
.
name
,
node
)
return
node
.
name
raise
RuntimeError
(
"node not found"
)
def
_find_first_valid_flow
(
self
,
flow
):
for
i
in
flow
:
if
i
is
not
None
:
return
i
raise
RuntimeError
(
"invalid flow"
)
def
find_node_flow
(
self
,
node
):
for
name
,
trace
in
self
.
flow_trace
.
items
():
for
i
in
trace
:
if
node
==
i
[
"node"
]:
return
name
,
i
raise
RuntimeError
(
"invalid node"
)
def
_get_flow_mix_node
(
self
,
node
):
if
_is_non_compute_node
(
node
):
return
None
_
,
node_trace
=
self
.
find_node_flow
(
node
)
if
len
(
node_trace
[
"outside_depend"
])
==
0
:
return
None
elif
len
(
node_trace
[
"outside_depend"
])
>
1
:
raise
NotImplementedError
vars
=
list
(
node_trace
[
"outside_depend"
][
0
].
values
())[
0
]
return
vars
def
_get_same_flow_node
(
self
,
node_list
,
node
):
name
,
_
=
self
.
find_node_flow
(
node
)
result
=
[]
for
i
in
self
.
flow_trace
[
name
]:
if
i
[
"node"
]
in
node_list
:
result
.
append
(
i
[
"node"
])
return
result
def
trace_flow
(
self
):
# init trace
self
.
_init_trace
()
for
node
in
self
.
node_list
:
# skip if non compute node
if
all
(
type
(
arg
)
!=
type
(
node
)
or
_is_non_compute_node_except_placeholder
(
arg
)
for
arg
in
node
.
args
)
or
_is_non_compute_node
(
node
):
continue
node_input_flows
=
[
self
.
_find_flow_for_node
(
arg
)
for
arg
in
node
.
args
]
node_domin_flow
=
self
.
_find_first_valid_flow
(
node_input_flows
)
self
.
_add_node
(
node_domin_flow
,
node
)
for
node_input_flow
,
arg
in
zip
(
node_input_flows
,
node
.
args
):
if
node_input_flow
is
None
:
continue
elif
node_input_flow
==
node_domin_flow
:
self
.
_add_inside_depend
(
node_domin_flow
,
node
,
arg
)
else
:
self
.
_add_outside_depend
(
node_domin_flow
,
node
,
arg
,
node_input_flow
)
return
self
.
flow_trace
def
_detect_flow
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
,
index_tracer
:
IndexTracer
):
inputs
,
outputs
=
_find_chunk_compute_input_and_output_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
chunk_info
=
{
"region"
:
(
start_idx
,
end_idx
),
"inputs"
:
inputs
,
"inputs_non_chunk"
:
[],
"inputs_dim"
:
start_dim
,
"outputs"
:
outputs
,
"outputs_dim"
:
end_dim
,
"args"
:
{},
}
flow_block
=
False
# TODO don't allow multi outputs now
if
len
(
outputs
)
>
1
:
flow_block
=
True
return
flow_block
,
chunk_info
for
idx
in
range
(
start_idx
,
end_idx
+
1
):
node
=
self
.
node_list
[
idx
]
mix_flow_node
=
self
.
_get_flow_mix_node
(
node
)
if
mix_flow_node
is
None
:
continue
# if there is a flow mix, op must be in [mul, add, matmul]
# element-wise op requires dim to be equal in every dim
if
any
(
n
in
node
.
name
for
n
in
[
"mul"
,
"add"
]):
for
i
in
node
.
args
:
if
type
(
i
)
==
type
(
mix_flow_node
)
and
i
!=
mix_flow_node
:
main_flow_var
=
i
# if mix flow is a broadcast in chunk dim,
# TODO: need to move that flow out of the chunk
mix_flow_node_dim
=
index_tracer
.
get_node_chunk_dim
(
self
.
node_list
[
end_idx
],
end_dim
,
node
)
if
mix_flow_node_dim
is
None
:
flow_block
=
True
break
if
_get_node_shape
(
mix_flow_node
)[
mix_flow_node_dim
]
==
1
:
flow_block
=
False
for
i
in
self
.
_get_same_flow_node
(
chunk_info
[
"inputs"
],
mix_flow_node
):
chunk_info
[
"inputs"
].
remove
(
i
)
# else, we need to chunk mix var as well
else
:
# TODO chunk another value
flow_block
=
True
break
else
:
raise
NotImplementedError
(
"%s not implemented"
%
node
.
name
)
if
flow_block
:
flow_block
=
True
return
flow_block
,
chunk_info
inputs_dim
=
[]
remove_inputs
=
[]
for
input_node
in
chunk_info
[
"inputs"
]:
input_dict
=
{}
for
user
in
input_node
.
users
.
keys
():
if
_is_non_compute_node
(
user
):
continue
user_idx
=
_find_idx_by_name
(
user
.
name
,
self
.
node_list
)
dim
=
None
if
start_dim
<=
user_idx
<
end_idx
:
dim
=
index_tracer
.
get_node_chunk_dim
(
self
.
node_list
[
end_idx
],
end_dim
,
input_node
)
elif
user_idx
==
end_idx
:
dim
=
end_dim
# n has relation with chunk dim
if
dim
is
not
None
and
_get_node_shape
(
user
)[
dim
]
!=
1
:
input_dict
[
user_idx
]
=
dim
if
len
(
input_dict
)
==
0
:
remove_inputs
.
append
(
input_node
)
else
:
inputs_dim
.
append
(
input_dict
)
chunk_info
[
"inputs_dim"
]
=
inputs_dim
for
i
in
remove_inputs
:
if
i
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs"
].
remove
(
i
)
# we need to log input nodes to avoid deleteing them in the loop
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
for
i
in
non_chunk_inputs
:
if
i
not
in
chunk_info
[
"inputs"
]:
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
return
flow_block
,
chunk_info
class
MemoryEstimator
(
object
):
class
MemoryEstimator
(
object
):
def
__init__
(
self
,
index_tracer
:
IndexTracer
)
->
None
:
def
__init__
(
self
,
index_tracer
:
IndexTracer
)
->
None
:
self
.
index_tracer
=
index_tracer
self
.
index_tracer
=
index_tracer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment