Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
4748967f
Commit
4748967f
authored
Jan 06, 2023
by
oahzxl
Browse files
ad reorder graph
parent
da407684
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
108 additions
and
106 deletions
+108
-106
colossalai/autochunk/reorder_graph.py
colossalai/autochunk/reorder_graph.py
+108
-0
colossalai/autochunk/trace_index.py
colossalai/autochunk/trace_index.py
+0
-106
No files found.
colossalai/autochunk/reorder_graph.py
0 → 100644
View file @
4748967f
from
.trace_index
import
TraceIndex
from
.utils
import
find_idx_by_name
class
ReorderGraph
(
object
):
def
__init__
(
self
,
index_tracer
:
TraceIndex
)
->
None
:
self
.
index_tracer
=
index_tracer
self
.
all_reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
index_tracer
.
idx_trace_list
))}
def
_get_reorder_map
(
self
,
chunk_info
):
reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
index_tracer
.
node_list
))}
chunk_region_start
=
chunk_info
[
"region"
][
0
]
chunk_region_end
=
chunk_info
[
"region"
][
1
]
chunk_prepose_nodes
=
chunk_info
[
"args"
][
"prepose_nodes"
]
chunk_prepose_nodes_idx
=
[
find_idx_by_name
(
i
.
name
,
self
.
index_tracer
.
node_list
)
for
i
in
chunk_prepose_nodes
]
# put prepose nodes ahead
for
idx
,
n
in
enumerate
(
chunk_prepose_nodes
):
n_idx
=
chunk_prepose_nodes_idx
[
idx
]
reorder_map
[
n_idx
]
=
chunk_region_start
+
idx
# put other nodes after prepose nodes
for
n
in
self
.
index_tracer
.
node_list
[
chunk_region_start
:
chunk_region_end
+
1
]:
if
n
in
chunk_prepose_nodes
:
continue
n_idx
=
find_idx_by_name
(
n
.
name
,
self
.
index_tracer
.
node_list
)
pos
=
sum
([
n_idx
<
i
for
i
in
chunk_prepose_nodes_idx
])
reorder_map
[
n_idx
]
=
n_idx
+
pos
return
reorder_map
def
_reorder_chunk_info
(
self
,
chunk_info
,
reorder_map
):
# update chunk info
chunk_info
[
"region"
]
=
(
chunk_info
[
"region"
][
0
]
+
len
(
chunk_info
[
"args"
][
"prepose_nodes"
]),
chunk_info
[
"region"
][
1
],
)
new_inputs_dim
=
[]
for
idx
,
input_dim
in
enumerate
(
chunk_info
[
"inputs_dim"
]):
new_input_dim
=
{}
for
k
,
v
in
input_dim
.
items
():
new_input_dim
[
reorder_map
[
k
]]
=
v
new_inputs_dim
.
append
(
new_input_dim
)
chunk_info
[
"inputs_dim"
]
=
new_inputs_dim
return
chunk_info
def
_update_all_reorder_map
(
self
,
reorder_map
):
for
origin_idx
,
map_idx
in
self
.
all_reorder_map
.
items
():
self
.
all_reorder_map
[
origin_idx
]
=
reorder_map
[
map_idx
]
def
_reorder_self_node_list
(
self
,
reorder_map
):
new_node_list
=
[
None
for
_
in
range
(
len
(
self
.
index_tracer
.
node_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_node_list
[
new_idx
]
=
self
.
index_tracer
.
node_list
[
old_idx
]
self
.
index_tracer
.
node_list
=
new_node_list
def
_reorder_idx_trace
(
self
,
reorder_map
):
# reorder list
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
index_tracer
.
idx_trace_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_idx_trace_list
[
new_idx
]
=
self
.
index_tracer
.
idx_trace_list
[
old_idx
]
self
.
index_tracer
.
idx_trace_list
=
new_idx_trace_list
# update compute
for
idx_trace
in
self
.
index_tracer
.
idx_trace_list
:
compute
=
idx_trace
[
"compute"
]
for
dim_compute
in
compute
:
for
idx
,
i
in
enumerate
(
dim_compute
):
dim_compute
[
idx
]
=
reorder_map
[
i
]
# update source
for
idx_trace
in
self
.
index_tracer
.
idx_trace_list
:
source
=
idx_trace
[
"source"
]
for
dim_idx
,
dim_source
in
enumerate
(
source
):
new_dim_source
=
{}
for
k
,
v
in
dim_source
.
items
():
new_dim_source
[
reorder_map
[
k
]]
=
v
source
[
dim_idx
]
=
new_dim_source
def
reorder_all
(
self
,
chunk_info
):
if
chunk_info
is
None
:
return
chunk_info
if
len
(
chunk_info
[
"args"
][
"prepose_nodes"
])
==
0
:
return
chunk_info
reorder_map
=
self
.
_get_reorder_map
(
chunk_info
)
self
.
_update_all_reorder_map
(
reorder_map
)
self
.
_reorder_idx_trace
(
reorder_map
)
self
.
_reorder_self_node_list
(
reorder_map
)
chunk_info
=
self
.
_reorder_chunk_info
(
chunk_info
,
reorder_map
)
return
chunk_info
def
reorder_node_list
(
self
,
node_list
):
new_node_list
=
[
None
for
_
in
range
(
len
(
node_list
))]
for
old_idx
,
new_idx
in
self
.
all_reorder_map
.
items
():
new_node_list
[
new_idx
]
=
node_list
[
old_idx
]
return
new_node_list
def
tmp_reorder
(
self
,
node_list
,
chunk_info
):
if
len
(
chunk_info
[
"args"
][
"prepose_nodes"
])
==
0
:
return
node_list
,
chunk_info
reorder_map
=
self
.
_get_reorder_map
(
chunk_info
)
# new tmp node list
new_node_list
=
[
None
for
_
in
range
(
len
(
node_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_node_list
[
new_idx
]
=
node_list
[
old_idx
]
chunk_info
=
self
.
_reorder_chunk_info
(
chunk_info
,
reorder_map
)
return
new_node_list
,
chunk_info
colossalai/autochunk/trace_index.py
View file @
4748967f
...
@@ -979,109 +979,3 @@ class TraceIndex(object):
...
@@ -979,109 +979,3 @@ class TraceIndex(object):
)
)
chunk_info
[
"reshape_size"
]
=
reshape_size
chunk_info
[
"reshape_size"
]
=
reshape_size
return
chunk_info
return
chunk_info
class
ReorderGraph
(
object
):
def
__init__
(
self
,
index_tracer
:
TraceIndex
)
->
None
:
self
.
index_tracer
=
index_tracer
self
.
all_reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
index_tracer
.
idx_trace_list
))}
def
_get_reorder_map
(
self
,
chunk_info
):
reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
index_tracer
.
node_list
))}
chunk_region_start
=
chunk_info
[
"region"
][
0
]
chunk_region_end
=
chunk_info
[
"region"
][
1
]
chunk_prepose_nodes
=
chunk_info
[
"args"
][
"prepose_nodes"
]
chunk_prepose_nodes_idx
=
[
find_idx_by_name
(
i
.
name
,
self
.
index_tracer
.
node_list
)
for
i
in
chunk_prepose_nodes
]
# put prepose nodes ahead
for
idx
,
n
in
enumerate
(
chunk_prepose_nodes
):
n_idx
=
chunk_prepose_nodes_idx
[
idx
]
reorder_map
[
n_idx
]
=
chunk_region_start
+
idx
# put other nodes after prepose nodes
for
n
in
self
.
index_tracer
.
node_list
[
chunk_region_start
:
chunk_region_end
+
1
]:
if
n
in
chunk_prepose_nodes
:
continue
n_idx
=
find_idx_by_name
(
n
.
name
,
self
.
index_tracer
.
node_list
)
pos
=
sum
([
n_idx
<
i
for
i
in
chunk_prepose_nodes_idx
])
reorder_map
[
n_idx
]
=
n_idx
+
pos
return
reorder_map
def
_reorder_chunk_info
(
self
,
chunk_info
,
reorder_map
):
# update chunk info
chunk_info
[
"region"
]
=
(
chunk_info
[
"region"
][
0
]
+
len
(
chunk_info
[
"args"
][
"prepose_nodes"
]),
chunk_info
[
"region"
][
1
],
)
new_inputs_dim
=
[]
for
idx
,
input_dim
in
enumerate
(
chunk_info
[
"inputs_dim"
]):
new_input_dim
=
{}
for
k
,
v
in
input_dim
.
items
():
new_input_dim
[
reorder_map
[
k
]]
=
v
new_inputs_dim
.
append
(
new_input_dim
)
chunk_info
[
"inputs_dim"
]
=
new_inputs_dim
return
chunk_info
def
_update_all_reorder_map
(
self
,
reorder_map
):
for
origin_idx
,
map_idx
in
self
.
all_reorder_map
.
items
():
self
.
all_reorder_map
[
origin_idx
]
=
reorder_map
[
map_idx
]
def
_reorder_self_node_list
(
self
,
reorder_map
):
new_node_list
=
[
None
for
_
in
range
(
len
(
self
.
index_tracer
.
node_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_node_list
[
new_idx
]
=
self
.
index_tracer
.
node_list
[
old_idx
]
self
.
index_tracer
.
node_list
=
new_node_list
def
_reorder_idx_trace
(
self
,
reorder_map
):
# reorder list
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
index_tracer
.
idx_trace_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_idx_trace_list
[
new_idx
]
=
self
.
index_tracer
.
idx_trace_list
[
old_idx
]
self
.
index_tracer
.
idx_trace_list
=
new_idx_trace_list
# update compute
for
idx_trace
in
self
.
index_tracer
.
idx_trace_list
:
compute
=
idx_trace
[
"compute"
]
for
dim_compute
in
compute
:
for
idx
,
i
in
enumerate
(
dim_compute
):
dim_compute
[
idx
]
=
reorder_map
[
i
]
# update source
for
idx_trace
in
self
.
index_tracer
.
idx_trace_list
:
source
=
idx_trace
[
"source"
]
for
dim_idx
,
dim_source
in
enumerate
(
source
):
new_dim_source
=
{}
for
k
,
v
in
dim_source
.
items
():
new_dim_source
[
reorder_map
[
k
]]
=
v
source
[
dim_idx
]
=
new_dim_source
def
reorder_all
(
self
,
chunk_info
):
if
chunk_info
is
None
:
return
chunk_info
if
len
(
chunk_info
[
"args"
][
"prepose_nodes"
])
==
0
:
return
chunk_info
reorder_map
=
self
.
_get_reorder_map
(
chunk_info
)
self
.
_update_all_reorder_map
(
reorder_map
)
self
.
_reorder_idx_trace
(
reorder_map
)
self
.
_reorder_self_node_list
(
reorder_map
)
chunk_info
=
self
.
_reorder_chunk_info
(
chunk_info
,
reorder_map
)
return
chunk_info
def
reorder_node_list
(
self
,
node_list
):
new_node_list
=
[
None
for
_
in
range
(
len
(
node_list
))]
for
old_idx
,
new_idx
in
self
.
all_reorder_map
.
items
():
new_node_list
[
new_idx
]
=
node_list
[
old_idx
]
return
new_node_list
def
tmp_reorder
(
self
,
node_list
,
chunk_info
):
if
len
(
chunk_info
[
"args"
][
"prepose_nodes"
])
==
0
:
return
node_list
,
chunk_info
reorder_map
=
self
.
_get_reorder_map
(
chunk_info
)
# new tmp node list
new_node_list
=
[
None
for
_
in
range
(
len
(
node_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_node_list
[
new_idx
]
=
node_list
[
old_idx
]
chunk_info
=
self
.
_reorder_chunk_info
(
chunk_info
,
reorder_map
)
return
new_node_list
,
chunk_info
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment