Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1bb1f2ad
Commit
1bb1f2ad
authored
Jan 09, 2023
by
oahzxl
Browse files
rename
parent
cb9817f7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
27 deletions
+27
-27
colossalai/autochunk/search_chunk.py
colossalai/autochunk/search_chunk.py
+2
-2
colossalai/autochunk/trace_indice.py
colossalai/autochunk/trace_indice.py
+25
-25
No files found.
colossalai/autochunk/search_chunk.py
View file @
1bb1f2ad
...
@@ -158,11 +158,11 @@ class SearchChunk(object):
...
@@ -158,11 +158,11 @@ class SearchChunk(object):
end_trace
=
output_trace
[
end_idx
]
end_trace
=
output_trace
[
end_idx
]
end_node
=
self
.
trace_indice
.
node_list
[
end_idx
]
end_node
=
self
.
trace_indice
.
node_list
[
end_idx
]
chunk_infos
=
[]
chunk_infos
=
[]
for
end_dim
,
_
in
enumerate
(
end_trace
[
"i
dx
"
]):
for
end_dim
,
_
in
enumerate
(
end_trace
[
"i
ndice
"
]):
if
len
(
start_traces
)
>
1
:
if
len
(
start_traces
)
>
1
:
continue
continue
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_node
,
start_trace
in
start_traces
.
items
():
for
start_dim
,
_
in
enumerate
(
start_trace
[
"i
dx
"
]):
for
start_dim
,
_
in
enumerate
(
start_trace
[
"i
ndice
"
]):
# dim size cannot be 1
# dim size cannot be 1
if
(
if
(
get_node_shape
(
end_node
)[
end_dim
]
==
1
get_node_shape
(
end_node
)[
end_dim
]
==
1
...
...
colossalai/autochunk/trace_indice.py
View file @
1bb1f2ad
...
@@ -19,12 +19,12 @@ class TraceIndice(object):
...
@@ -19,12 +19,12 @@ class TraceIndice(object):
for
n
in
self
.
node_list
:
for
n
in
self
.
node_list
:
if
get_node_shape
(
n
)
!=
None
:
if
get_node_shape
(
n
)
!=
None
:
cur_trace
=
{
cur_trace
=
{
"i
dx
"
:
[
None
for
_
in
range
(
len
(
get_node_shape
(
n
)))],
"i
ndice
"
:
[
None
for
_
in
range
(
len
(
get_node_shape
(
n
)))],
"compute"
:
[[]
for
_
in
range
(
len
(
get_node_shape
(
n
)))],
"compute"
:
[[]
for
_
in
range
(
len
(
get_node_shape
(
n
)))],
"source"
:
[{}
for
_
in
range
(
len
(
get_node_shape
(
n
)))],
"source"
:
[{}
for
_
in
range
(
len
(
get_node_shape
(
n
)))],
}
}
else
:
else
:
cur_trace
=
{
"i
dx
"
:
[],
"compute"
:
[],
"source"
:
[]}
cur_trace
=
{
"i
ndice
"
:
[],
"compute"
:
[],
"source"
:
[]}
indice_trace_list
.
append
(
cur_trace
)
indice_trace_list
.
append
(
cur_trace
)
return
indice_trace_list
return
indice_trace_list
...
@@ -39,12 +39,12 @@ class TraceIndice(object):
...
@@ -39,12 +39,12 @@ class TraceIndice(object):
return
self
.
indice_count
return
self
.
indice_count
def
_del_dim
(
self
,
idx
,
dim_idx
):
def
_del_dim
(
self
,
idx
,
dim_idx
):
self
.
indice_trace_list
[
idx
][
"i
dx
"
].
pop
(
dim_idx
)
self
.
indice_trace_list
[
idx
][
"i
ndice
"
].
pop
(
dim_idx
)
self
.
indice_trace_list
[
idx
][
"compute"
].
pop
(
dim_idx
)
self
.
indice_trace_list
[
idx
][
"compute"
].
pop
(
dim_idx
)
self
.
indice_trace_list
[
idx
][
"source"
].
pop
(
dim_idx
)
self
.
indice_trace_list
[
idx
][
"source"
].
pop
(
dim_idx
)
def
_add_dim
(
self
,
node_idx
,
dim_idx
):
def
_add_dim
(
self
,
node_idx
,
dim_idx
):
self
.
indice_trace_list
[
node_idx
][
"i
dx
"
].
insert
(
dim_idx
,
self
.
_add_indice
())
self
.
indice_trace_list
[
node_idx
][
"i
ndice
"
].
insert
(
dim_idx
,
self
.
_add_indice
())
self
.
indice_trace_list
[
node_idx
][
"compute"
].
insert
(
dim_idx
,
[])
self
.
indice_trace_list
[
node_idx
][
"compute"
].
insert
(
dim_idx
,
[])
self
.
indice_trace_list
[
node_idx
][
"source"
].
insert
(
dim_idx
,
{})
self
.
indice_trace_list
[
node_idx
][
"source"
].
insert
(
dim_idx
,
{})
...
@@ -58,7 +58,7 @@ class TraceIndice(object):
...
@@ -58,7 +58,7 @@ class TraceIndice(object):
node_to_dim
=
self
.
_transform_indice
(
node_to
,
node_to_dim
)
node_to_dim
=
self
.
_transform_indice
(
node_to
,
node_to_dim
)
node_from_trace
=
self
.
_find_trace_from_node
(
node_from
)
node_from_trace
=
self
.
_find_trace_from_node
(
node_from
)
node_to_trace
=
self
.
_find_trace_from_node
(
node_to
)
node_to_trace
=
self
.
_find_trace_from_node
(
node_to
)
node_to_trace
[
"i
dx
"
][
node_to_dim
]
=
node_from_trace
[
"i
dx
"
][
node_from_dim
]
node_to_trace
[
"i
ndice
"
][
node_to_dim
]
=
node_from_trace
[
"i
ndice
"
][
node_from_dim
]
node_to_trace
[
"compute"
][
node_to_dim
]
=
copy
.
deepcopy
(
node_to_trace
[
"compute"
][
node_to_dim
]
=
copy
.
deepcopy
(
node_from_trace
[
"compute"
][
node_from_dim
]
node_from_trace
[
"compute"
][
node_from_dim
]
)
)
...
@@ -181,7 +181,7 @@ class TraceIndice(object):
...
@@ -181,7 +181,7 @@ class TraceIndice(object):
idx (list): idx of the node
idx (list): idx of the node
"""
"""
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
return
self
.
indice_trace_list
[
node_idx
][
"i
dx
"
]
return
self
.
indice_trace_list
[
node_idx
][
"i
ndice
"
]
def
_find_compute_trace_from_node
(
self
,
node
):
def
_find_compute_trace_from_node
(
self
,
node
):
"""
"""
...
@@ -195,7 +195,7 @@ class TraceIndice(object):
...
@@ -195,7 +195,7 @@ class TraceIndice(object):
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
return
self
.
indice_trace_list
[
node_idx
][
"compute"
]
return
self
.
indice_trace_list
[
node_idx
][
"compute"
]
def
_assign_inde
x
_as_input
(
self
,
node
,
node_idx
,
input_node
=
None
):
def
_assign_ind
ic
e_as_input
(
self
,
node
,
node_idx
,
input_node
=
None
):
"""
"""
Assign node's trace as its input node.
Assign node's trace as its input node.
...
@@ -206,10 +206,10 @@ class TraceIndice(object):
...
@@ -206,10 +206,10 @@ class TraceIndice(object):
if
input_node
==
None
:
if
input_node
==
None
:
input_node
=
node
.
args
[
0
]
input_node
=
node
.
args
[
0
]
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx_trace
=
self
.
indice_trace_list
[
input_node_idx
][
"i
dx
"
]
input_node_idx_trace
=
self
.
indice_trace_list
[
input_node_idx
][
"i
ndice
"
]
new_idx_trace
=
copy
.
deepcopy
(
input_node_idx_trace
)
new_idx_trace
=
copy
.
deepcopy
(
input_node_idx_trace
)
self
.
indice_trace_list
[
node_idx
][
"i
dx
"
]
=
new_idx_trace
self
.
indice_trace_list
[
node_idx
][
"i
ndice
"
]
=
new_idx_trace
self
.
_inherit_all_computation
(
input_node
,
node
)
self
.
_inherit_all_computation
(
input_node
,
node
)
...
@@ -225,7 +225,7 @@ class TraceIndice(object):
...
@@ -225,7 +225,7 @@ class TraceIndice(object):
new_trace
=
[]
new_trace
=
[]
for
_
in
shape
:
for
_
in
shape
:
new_trace
.
append
(
self
.
_add_indice
())
new_trace
.
append
(
self
.
_add_indice
())
self
.
indice_trace_list
[
node_idx
][
"i
dx
"
]
=
new_trace
self
.
indice_trace_list
[
node_idx
][
"i
ndice
"
]
=
new_trace
def
_assign_transpose_indice
(
self
,
node
,
node_idx
):
def
_assign_transpose_indice
(
self
,
node
,
node_idx
):
"""
"""
...
@@ -240,7 +240,7 @@ class TraceIndice(object):
...
@@ -240,7 +240,7 @@ class TraceIndice(object):
input_node
=
node
.
args
[
0
]
input_node
=
node
.
args
[
0
]
tranpose_dim
=
node
.
args
[
1
:]
tranpose_dim
=
node
.
args
[
1
:]
self
.
_assign_inde
x
_as_input
(
node
,
node_idx
,
input_node
)
self
.
_assign_ind
ic
e_as_input
(
node
,
node_idx
,
input_node
)
self
.
_inherit_indice
(
input_node
,
tranpose_dim
[
1
],
node
,
tranpose_dim
[
0
])
self
.
_inherit_indice
(
input_node
,
tranpose_dim
[
1
],
node
,
tranpose_dim
[
0
])
self
.
_inherit_indice
(
input_node
,
tranpose_dim
[
0
],
node
,
tranpose_dim
[
1
])
self
.
_inherit_indice
(
input_node
,
tranpose_dim
[
0
],
node
,
tranpose_dim
[
1
])
...
@@ -257,7 +257,7 @@ class TraceIndice(object):
...
@@ -257,7 +257,7 @@ class TraceIndice(object):
permute_dim
=
node
.
args
[
1
:]
permute_dim
=
node
.
args
[
1
:]
input_node
=
node
.
args
[
0
]
input_node
=
node
.
args
[
0
]
self
.
_assign_inde
x
_as_input
(
node
,
node_idx
,
input_node
)
self
.
_assign_ind
ic
e_as_input
(
node
,
node_idx
,
input_node
)
for
idx
,
d
in
enumerate
(
permute_dim
):
for
idx
,
d
in
enumerate
(
permute_dim
):
self
.
_inherit_indice
(
input_node
,
d
,
node
,
idx
)
self
.
_inherit_indice
(
input_node
,
d
,
node
,
idx
)
...
@@ -278,7 +278,7 @@ class TraceIndice(object):
...
@@ -278,7 +278,7 @@ class TraceIndice(object):
else
:
else
:
input_node
,
weight
,
bias
=
node
.
args
input_node
,
weight
,
bias
=
node
.
args
self
.
_assign_inde
x
_as_input
(
node
,
node_idx
)
self
.
_assign_ind
ic
e_as_input
(
node
,
node_idx
)
self
.
_inherit_indice
(
weight
,
1
,
node
,
-
1
)
self
.
_inherit_indice
(
weight
,
1
,
node
,
-
1
)
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
...
@@ -301,7 +301,7 @@ class TraceIndice(object):
...
@@ -301,7 +301,7 @@ class TraceIndice(object):
matmul_left
,
matmul_right
=
node
.
args
matmul_left
,
matmul_right
=
node
.
args
assert
len
(
get_node_shape
(
matmul_left
))
==
len
(
get_node_shape
(
matmul_right
))
assert
len
(
get_node_shape
(
matmul_left
))
==
len
(
get_node_shape
(
matmul_right
))
self
.
_assign_inde
x
_as_input
(
node
,
node_idx
,
matmul_left
)
self
.
_assign_ind
ic
e_as_input
(
node
,
node_idx
,
matmul_left
)
self
.
_inherit_indice
(
matmul_right
,
-
1
,
node
,
-
1
)
self
.
_inherit_indice
(
matmul_right
,
-
1
,
node
,
-
1
)
self
.
_mark_computation_from_node
(
matmul_right
,
node
,
[
-
1
,
-
2
])
self
.
_mark_computation_from_node
(
matmul_right
,
node
,
[
-
1
,
-
2
])
...
@@ -318,7 +318,7 @@ class TraceIndice(object):
...
@@ -318,7 +318,7 @@ class TraceIndice(object):
node (node)
node (node)
node_idx (int)
node_idx (int)
"""
"""
self
.
_assign_inde
x
_as_input
(
node
,
idx
)
self
.
_assign_ind
ic
e_as_input
(
node
,
idx
)
self
.
_mark_computation
(
node
,
idx
,
[
-
1
])
self
.
_mark_computation
(
node
,
idx
,
[
-
1
])
def
_assign_elementwise_indice
(
self
,
node
,
idx
):
def
_assign_elementwise_indice
(
self
,
node
,
idx
):
...
@@ -331,7 +331,7 @@ class TraceIndice(object):
...
@@ -331,7 +331,7 @@ class TraceIndice(object):
node (node)
node (node)
node_idx (int)
node_idx (int)
"""
"""
self
.
_assign_inde
x
_as_input
(
node
,
idx
)
self
.
_assign_ind
ic
e_as_input
(
node
,
idx
)
nodes_in
=
[]
nodes_in
=
[]
for
node_in
in
node
.
args
:
for
node_in
in
node
.
args
:
if
type
(
node_in
)
==
type
(
node
):
if
type
(
node_in
)
==
type
(
node
):
...
@@ -346,7 +346,7 @@ class TraceIndice(object):
...
@@ -346,7 +346,7 @@ class TraceIndice(object):
self
.
_mark_indice_equal
(
nodes_in
[
0
],
i
,
nodes_in
[
1
],
i
)
self
.
_mark_indice_equal
(
nodes_in
[
0
],
i
,
nodes_in
[
1
],
i
)
def
_assgin_no_change_indice
(
self
,
node
,
idx
):
def
_assgin_no_change_indice
(
self
,
node
,
idx
):
self
.
_assign_inde
x
_as_input
(
node
,
idx
)
self
.
_assign_ind
ic
e_as_input
(
node
,
idx
)
for
node_in
in
node
.
args
:
for
node_in
in
node
.
args
:
if
type
(
node_in
)
==
type
(
node
):
if
type
(
node_in
)
==
type
(
node
):
self
.
_mark_computation_from_node
(
node_in
,
node
)
self
.
_mark_computation_from_node
(
node_in
,
node
)
...
@@ -398,7 +398,7 @@ class TraceIndice(object):
...
@@ -398,7 +398,7 @@ class TraceIndice(object):
node (node)
node (node)
node_idx (int)
node_idx (int)
"""
"""
self
.
_assign_inde
x
_as_input
(
node
,
idx
)
self
.
_assign_ind
ic
e_as_input
(
node
,
idx
)
self
.
_mark_computation
(
node
,
idx
,
[
node
.
kwargs
[
"dim"
]])
self
.
_mark_computation
(
node
,
idx
,
[
node
.
kwargs
[
"dim"
]])
def
_assign_unsqueeze_indice
(
self
,
node
,
node_idx
):
def
_assign_unsqueeze_indice
(
self
,
node
,
node_idx
):
...
@@ -411,7 +411,7 @@ class TraceIndice(object):
...
@@ -411,7 +411,7 @@ class TraceIndice(object):
node_idx (int)
node_idx (int)
"""
"""
self
.
_del_dim
(
node_idx
,
-
1
)
self
.
_del_dim
(
node_idx
,
-
1
)
self
.
_assign_inde
x
_as_input
(
node
,
node_idx
)
self
.
_assign_ind
ic
e_as_input
(
node
,
node_idx
)
self
.
_add_dim
(
node_idx
,
node
.
args
[
1
])
self
.
_add_dim
(
node_idx
,
node
.
args
[
1
])
def
_assign_dropout_indice
(
self
,
node
,
node_idx
):
def
_assign_dropout_indice
(
self
,
node
,
node_idx
):
...
@@ -423,7 +423,7 @@ class TraceIndice(object):
...
@@ -423,7 +423,7 @@ class TraceIndice(object):
node (node)
node (node)
node_idx (int)
node_idx (int)
"""
"""
self
.
_assign_inde
x
_as_input
(
node
,
node_idx
)
self
.
_assign_ind
ic
e_as_input
(
node
,
node_idx
)
def
_assign_ones_like_indice
(
self
,
node
,
node_idx
):
def
_assign_ones_like_indice
(
self
,
node
,
node_idx
):
"""
"""
...
@@ -497,7 +497,7 @@ class TraceIndice(object):
...
@@ -497,7 +497,7 @@ class TraceIndice(object):
# get new index
# get new index
origin_trace
=
self
.
_find_indice_trace_from_node
(
origin_node
)
origin_trace
=
self
.
_find_indice_trace_from_node
(
origin_node
)
self
.
_assign_inde
x
_as_input
(
node
,
node_idx
,
origin_node
)
self
.
_assign_ind
ic
e_as_input
(
node
,
node_idx
,
origin_node
)
dim_from
.
reverse
()
dim_from
.
reverse
()
for
i
in
dim_from
:
for
i
in
dim_from
:
self
.
_del_dim
(
node_idx
,
i
)
self
.
_del_dim
(
node_idx
,
i
)
...
@@ -516,7 +516,7 @@ class TraceIndice(object):
...
@@ -516,7 +516,7 @@ class TraceIndice(object):
view_dict
=
{
view_dict
=
{
"idx_from"
:
[
origin_trace
[
i
]
for
i
in
dim_from
],
"idx_from"
:
[
origin_trace
[
i
]
for
i
in
dim_from
],
"dim_from"
:
dim_from
,
"dim_from"
:
dim_from
,
"idx_to"
:
[
self
.
indice_trace_list
[
node_idx
][
"i
dx
"
][
i
]
for
i
in
dim_to
],
"idx_to"
:
[
self
.
indice_trace_list
[
node_idx
][
"i
ndice
"
][
i
]
for
i
in
dim_to
],
"dim_to"
:
dim_to
,
"dim_to"
:
dim_to
,
}
}
self
.
indice_view_list
[
node
]
=
view_dict
self
.
indice_view_list
[
node
]
=
view_dict
...
@@ -528,9 +528,9 @@ class TraceIndice(object):
...
@@ -528,9 +528,9 @@ class TraceIndice(object):
merge_to
=
min
(
idx
)
merge_to
=
min
(
idx
)
merge_from
=
max
(
idx
)
merge_from
=
max
(
idx
)
for
trace
in
self
.
indice_trace_list
:
for
trace
in
self
.
indice_trace_list
:
if
merge_from
in
trace
[
"i
dx
"
]:
if
merge_from
in
trace
[
"i
ndice
"
]:
trace
[
"i
dx
"
]
=
[
trace
[
"i
ndice
"
]
=
[
merge_to
if
i
==
merge_from
else
i
for
i
in
trace
[
"i
dx
"
]
merge_to
if
i
==
merge_from
else
i
for
i
in
trace
[
"i
ndice
"
]
]
]
def
trace_index
(
self
):
def
trace_index
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment