Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cb9817f7
Commit
cb9817f7
authored
Jan 09, 2023
by
oahzxl
Browse files
rename function from index to indice
parent
0ea903b9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
91 additions
and
91 deletions
+91
-91
colossalai/autochunk/reorder_graph.py
colossalai/autochunk/reorder_graph.py
+6
-6
colossalai/autochunk/search_chunk.py
colossalai/autochunk/search_chunk.py
+1
-1
colossalai/autochunk/trace_flow.py
colossalai/autochunk/trace_flow.py
+1
-1
colossalai/autochunk/trace_indice.py
colossalai/autochunk/trace_indice.py
+83
-83
No files found.
colossalai/autochunk/reorder_graph.py
View file @
cb9817f7
...
@@ -6,7 +6,7 @@ class ReorderGraph(object):
...
@@ -6,7 +6,7 @@ class ReorderGraph(object):
def
__init__
(
self
,
trace_indice
:
TraceIndice
)
->
None
:
def
__init__
(
self
,
trace_indice
:
TraceIndice
)
->
None
:
self
.
trace_indice
=
trace_indice
self
.
trace_indice
=
trace_indice
self
.
all_reorder_map
=
{
self
.
all_reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
trace_indice
.
i
dx
_trace_list
))
i
:
i
for
i
in
range
(
len
(
self
.
trace_indice
.
i
ndice
_trace_list
))
}
}
def
_get_reorder_map
(
self
,
chunk_info
):
def
_get_reorder_map
(
self
,
chunk_info
):
...
@@ -60,18 +60,18 @@ class ReorderGraph(object):
...
@@ -60,18 +60,18 @@ class ReorderGraph(object):
def
_reorder_idx_trace
(
self
,
reorder_map
):
def
_reorder_idx_trace
(
self
,
reorder_map
):
# reorder list
# reorder list
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_indice
.
i
dx
_trace_list
))]
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_indice
.
i
ndice
_trace_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_idx_trace_list
[
new_idx
]
=
self
.
trace_indice
.
i
dx
_trace_list
[
old_idx
]
new_idx_trace_list
[
new_idx
]
=
self
.
trace_indice
.
i
ndice
_trace_list
[
old_idx
]
self
.
trace_indice
.
i
dx
_trace_list
=
new_idx_trace_list
self
.
trace_indice
.
i
ndice
_trace_list
=
new_idx_trace_list
# update compute
# update compute
for
idx_trace
in
self
.
trace_indice
.
i
dx
_trace_list
:
for
idx_trace
in
self
.
trace_indice
.
i
ndice
_trace_list
:
compute
=
idx_trace
[
"compute"
]
compute
=
idx_trace
[
"compute"
]
for
dim_compute
in
compute
:
for
dim_compute
in
compute
:
for
idx
,
i
in
enumerate
(
dim_compute
):
for
idx
,
i
in
enumerate
(
dim_compute
):
dim_compute
[
idx
]
=
reorder_map
[
i
]
dim_compute
[
idx
]
=
reorder_map
[
i
]
# update source
# update source
for
idx_trace
in
self
.
trace_indice
.
i
dx
_trace_list
:
for
idx_trace
in
self
.
trace_indice
.
i
ndice
_trace_list
:
source
=
idx_trace
[
"source"
]
source
=
idx_trace
[
"source"
]
for
dim_idx
,
dim_source
in
enumerate
(
source
):
for
dim_idx
,
dim_source
in
enumerate
(
source
):
new_dim_source
=
{}
new_dim_source
=
{}
...
...
colossalai/autochunk/search_chunk.py
View file @
cb9817f7
...
@@ -205,7 +205,7 @@ class SearchChunk(object):
...
@@ -205,7 +205,7 @@ class SearchChunk(object):
possible_chunk_region (List)
possible_chunk_region (List)
"""
"""
possible_chunk_region
=
[]
possible_chunk_region
=
[]
output_trace
=
copy
.
deepcopy
(
self
.
trace_indice
.
i
dx
_trace_list
)
output_trace
=
copy
.
deepcopy
(
self
.
trace_indice
.
i
ndice
_trace_list
)
input_trace
=
[]
# trace of a node's input nodes
input_trace
=
[]
# trace of a node's input nodes
for
_
,
n
in
enumerate
(
self
.
trace_indice
.
node_list
):
for
_
,
n
in
enumerate
(
self
.
trace_indice
.
node_list
):
cur_trace
=
{}
cur_trace
=
{}
...
...
colossalai/autochunk/trace_flow.py
View file @
cb9817f7
...
@@ -406,7 +406,7 @@ class TraceFlow(object):
...
@@ -406,7 +406,7 @@ class TraceFlow(object):
for
node
in
self
.
trace_indice
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
for
node
in
self
.
trace_indice
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
if
any
(
i
in
node
.
name
for
i
in
[
"reshape"
,
"view"
]):
if
any
(
i
in
node
.
name
for
i
in
[
"reshape"
,
"view"
]):
reshape_args
=
node
.
args
[
1
:]
reshape_args
=
node
.
args
[
1
:]
reshape_log
=
self
.
trace_indice
.
i
dx
_view_list
[
node
]
reshape_log
=
self
.
trace_indice
.
i
ndice
_view_list
[
node
]
chunk_dim
=
chunk_info
[
"node_chunk_dim"
][
node
][
"chunk_dim"
]
chunk_dim
=
chunk_info
[
"node_chunk_dim"
][
node
][
"chunk_dim"
]
reshape_size
[
node
.
name
]
=
{}
reshape_size
[
node
.
name
]
=
{}
for
reshape_arg_dim
,
reshape_arg
in
enumerate
(
reshape_args
):
for
reshape_arg_dim
,
reshape_arg
in
enumerate
(
reshape_args
):
...
...
colossalai/autochunk/trace_indice.py
View file @
cb9817f7
...
@@ -9,13 +9,13 @@ from .utils import (
...
@@ -9,13 +9,13 @@ from .utils import (
class
TraceIndice
(
object
):
class
TraceIndice
(
object
):
def
__init__
(
self
,
node_list
)
->
None
:
def
__init__
(
self
,
node_list
)
->
None
:
self
.
node_list
=
node_list
self
.
node_list
=
node_list
self
.
i
dx
_trace_list
=
self
.
_init_i
dx
_trace_list
()
self
.
i
ndice
_trace_list
=
self
.
_init_i
ndice
_trace_list
()
self
.
i
dx
_trace_equal
=
[]
self
.
i
ndice
_trace_equal
=
[]
self
.
i
dx
_view_list
=
{}
self
.
i
ndice
_view_list
=
{}
self
.
i
dx
_count
=
-
1
self
.
i
ndice
_count
=
-
1
def
_init_i
dx
_trace_list
(
self
):
def
_init_i
ndice
_trace_list
(
self
):
i
dx
_trace_list
=
[]
i
ndice
_trace_list
=
[]
for
n
in
self
.
node_list
:
for
n
in
self
.
node_list
:
if
get_node_shape
(
n
)
!=
None
:
if
get_node_shape
(
n
)
!=
None
:
cur_trace
=
{
cur_trace
=
{
...
@@ -25,37 +25,37 @@ class TraceIndice(object):
...
@@ -25,37 +25,37 @@ class TraceIndice(object):
}
}
else
:
else
:
cur_trace
=
{
"idx"
:
[],
"compute"
:
[],
"source"
:
[]}
cur_trace
=
{
"idx"
:
[],
"compute"
:
[],
"source"
:
[]}
i
dx
_trace_list
.
append
(
cur_trace
)
i
ndice
_trace_list
.
append
(
cur_trace
)
return
i
dx
_trace_list
return
i
ndice
_trace_list
def
_add_inde
x
(
self
):
def
_add_ind
ic
e
(
self
):
"""
"""
Update the count and return it. To record the idx number.
Update the count and return it. To record the idx number.
Returns:
Returns:
idx_count: int
idx_count: int
"""
"""
self
.
i
dx
_count
+=
1
self
.
i
ndice
_count
+=
1
return
self
.
i
dx
_count
return
self
.
i
ndice
_count
def
_del_dim
(
self
,
idx
,
dim_idx
):
def
_del_dim
(
self
,
idx
,
dim_idx
):
self
.
i
dx
_trace_list
[
idx
][
"idx"
].
pop
(
dim_idx
)
self
.
i
ndice
_trace_list
[
idx
][
"idx"
].
pop
(
dim_idx
)
self
.
i
dx
_trace_list
[
idx
][
"compute"
].
pop
(
dim_idx
)
self
.
i
ndice
_trace_list
[
idx
][
"compute"
].
pop
(
dim_idx
)
self
.
i
dx
_trace_list
[
idx
][
"source"
].
pop
(
dim_idx
)
self
.
i
ndice
_trace_list
[
idx
][
"source"
].
pop
(
dim_idx
)
def
_add_dim
(
self
,
node_idx
,
dim_idx
):
def
_add_dim
(
self
,
node_idx
,
dim_idx
):
self
.
i
dx
_trace_list
[
node_idx
][
"idx"
].
insert
(
dim_idx
,
self
.
_add_inde
x
())
self
.
i
ndice
_trace_list
[
node_idx
][
"idx"
].
insert
(
dim_idx
,
self
.
_add_ind
ic
e
())
self
.
i
dx
_trace_list
[
node_idx
][
"compute"
].
insert
(
dim_idx
,
[])
self
.
i
ndice
_trace_list
[
node_idx
][
"compute"
].
insert
(
dim_idx
,
[])
self
.
i
dx
_trace_list
[
node_idx
][
"source"
].
insert
(
dim_idx
,
{})
self
.
i
ndice
_trace_list
[
node_idx
][
"source"
].
insert
(
dim_idx
,
{})
def
_transform_inde
x
(
self
,
node
,
node_dim
):
def
_transform_ind
ic
e
(
self
,
node
,
node_dim
):
node_idx
=
self
.
_find_i
dx
_trace_from_node
(
node
)
node_idx
=
self
.
_find_i
ndice
_trace_from_node
(
node
)
dims
=
list
(
range
(
len
(
node_idx
)))
dims
=
list
(
range
(
len
(
node_idx
)))
return
dims
[
node_dim
]
return
dims
[
node_dim
]
def
_inherit_inde
x
(
self
,
node_from
,
node_from_dim
,
node_to
,
node_to_dim
):
def
_inherit_ind
ic
e
(
self
,
node_from
,
node_from_dim
,
node_to
,
node_to_dim
):
node_from_dim
=
self
.
_transform_inde
x
(
node_from
,
node_from_dim
)
node_from_dim
=
self
.
_transform_ind
ic
e
(
node_from
,
node_from_dim
)
node_to_dim
=
self
.
_transform_inde
x
(
node_to
,
node_to_dim
)
node_to_dim
=
self
.
_transform_ind
ic
e
(
node_to
,
node_to_dim
)
node_from_trace
=
self
.
_find_trace_from_node
(
node_from
)
node_from_trace
=
self
.
_find_trace_from_node
(
node_from
)
node_to_trace
=
self
.
_find_trace_from_node
(
node_to
)
node_to_trace
=
self
.
_find_trace_from_node
(
node_to
)
node_to_trace
[
"idx"
][
node_to_dim
]
=
node_from_trace
[
"idx"
][
node_from_dim
]
node_to_trace
[
"idx"
][
node_to_dim
]
=
node_from_trace
[
"idx"
][
node_from_dim
]
...
@@ -73,9 +73,9 @@ class TraceIndice(object):
...
@@ -73,9 +73,9 @@ class TraceIndice(object):
node_to_compute
[
i
]
=
copy
.
deepcopy
(
node_from_compute
[
i
])
node_to_compute
[
i
]
=
copy
.
deepcopy
(
node_from_compute
[
i
])
def
_add_source
(
self
,
node_from
,
node_from_dim
,
node_to
,
node_to_dim
,
init
=
False
):
def
_add_source
(
self
,
node_from
,
node_from_dim
,
node_to
,
node_to_dim
,
init
=
False
):
node_from_dim
=
self
.
_transform_inde
x
(
node_from
,
node_from_dim
)
node_from_dim
=
self
.
_transform_ind
ic
e
(
node_from
,
node_from_dim
)
node_from_trace_source
=
self
.
_find_source_trace_from_node
(
node_from
)
node_from_trace_source
=
self
.
_find_source_trace_from_node
(
node_from
)
node_to_dim
=
self
.
_transform_inde
x
(
node_to
,
node_to_dim
)
node_to_dim
=
self
.
_transform_ind
ic
e
(
node_to
,
node_to_dim
)
node_to_trace_source
=
self
.
_find_source_trace_from_node
(
node_to
)
node_to_trace_source
=
self
.
_find_source_trace_from_node
(
node_to
)
node_from_idx
=
find_idx_by_name
(
node_from
.
name
,
self
.
node_list
)
node_from_idx
=
find_idx_by_name
(
node_from
.
name
,
self
.
node_list
)
if
init
:
if
init
:
...
@@ -99,19 +99,19 @@ class TraceIndice(object):
...
@@ -99,19 +99,19 @@ class TraceIndice(object):
if
exclude
==
None
:
if
exclude
==
None
:
exclude
=
[]
exclude
=
[]
else
:
else
:
exclude
=
[
self
.
_transform_inde
x
(
node_to
,
i
)
for
i
in
exclude
]
exclude
=
[
self
.
_transform_ind
ic
e
(
node_to
,
i
)
for
i
in
exclude
]
node_from_compute
=
self
.
_find_compute_trace_from_node
(
node_from
)
node_from_compute
=
self
.
_find_compute_trace_from_node
(
node_from
)
node_to_compute
=
self
.
_find_compute_trace_from_node
(
node_to
)
node_to_compute
=
self
.
_find_compute_trace_from_node
(
node_to
)
# assert len(node_from_compute) == len(node_to_compute)
# assert len(node_from_compute) == len(node_to_compute)
for
i
in
range
(
-
1
,
-
min
(
len
(
node_from_compute
),
len
(
node_to_compute
))
-
1
,
-
1
):
for
i
in
range
(
-
1
,
-
min
(
len
(
node_from_compute
),
len
(
node_to_compute
))
-
1
,
-
1
):
if
self
.
_transform_inde
x
(
node_to
,
i
)
in
exclude
:
if
self
.
_transform_ind
ic
e
(
node_to
,
i
)
in
exclude
:
continue
continue
self
.
_add_source
(
node_from
,
i
,
node_to
,
i
)
self
.
_add_source
(
node_from
,
i
,
node_to
,
i
)
for
j
in
node_from_compute
[
i
]:
for
j
in
node_from_compute
[
i
]:
if
j
not
in
node_to_compute
[
i
]:
if
j
not
in
node_to_compute
[
i
]:
node_to_compute
[
i
].
append
(
j
)
node_to_compute
[
i
].
append
(
j
)
def
_mark_i
dx
_equal
(
self
,
node1
,
dim1
,
node2
,
dim2
):
def
_mark_i
ndice
_equal
(
self
,
node1
,
dim1
,
node2
,
dim2
):
"""
"""
Mark 2 index to be equal.
Mark 2 index to be equal.
...
@@ -140,8 +140,8 @@ class TraceIndice(object):
...
@@ -140,8 +140,8 @@ class TraceIndice(object):
dims
=
list
(
range
(
len
(
get_node_shape
(
node
))))
dims
=
list
(
range
(
len
(
get_node_shape
(
node
))))
for
d
in
dim
:
for
d
in
dim
:
cur_dim
=
dims
[
d
]
cur_dim
=
dims
[
d
]
if
idx
not
in
self
.
i
dx
_trace_list
[
idx
][
"compute"
][
cur_dim
]:
if
idx
not
in
self
.
i
ndice
_trace_list
[
idx
][
"compute"
][
cur_dim
]:
self
.
i
dx
_trace_list
[
idx
][
"compute"
][
cur_dim
].
append
(
idx
)
self
.
i
ndice
_trace_list
[
idx
][
"compute"
][
cur_dim
].
append
(
idx
)
def
_find_trace_from_node
(
self
,
node
):
def
_find_trace_from_node
(
self
,
node
):
"""
"""
...
@@ -154,7 +154,7 @@ class TraceIndice(object):
...
@@ -154,7 +154,7 @@ class TraceIndice(object):
compute (list): computed idx of the node.
compute (list): computed idx of the node.
"""
"""
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_dict
=
self
.
i
dx
_trace_list
[
node_idx
]
node_dict
=
self
.
i
ndice
_trace_list
[
node_idx
]
return
node_dict
return
node_dict
def
_find_source_trace_from_node
(
self
,
node
):
def
_find_source_trace_from_node
(
self
,
node
):
...
@@ -168,10 +168,10 @@ class TraceIndice(object):
...
@@ -168,10 +168,10 @@ class TraceIndice(object):
compute (list): computed idx of the node.
compute (list): computed idx of the node.
"""
"""
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_dict
=
self
.
i
dx
_trace_list
[
node_idx
]
node_dict
=
self
.
i
ndice
_trace_list
[
node_idx
]
return
node_dict
[
"source"
]
return
node_dict
[
"source"
]
def
_find_i
dx
_trace_from_node
(
self
,
node
):
def
_find_i
ndice
_trace_from_node
(
self
,
node
):
"""
"""
Find node idx trace by the node.
Find node idx trace by the node.
...
@@ -181,7 +181,7 @@ class TraceIndice(object):
...
@@ -181,7 +181,7 @@ class TraceIndice(object):
idx (list): idx of the node
idx (list): idx of the node
"""
"""
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
return
self
.
i
dx
_trace_list
[
node_idx
][
"idx"
]
return
self
.
i
ndice
_trace_list
[
node_idx
][
"idx"
]
def
_find_compute_trace_from_node
(
self
,
node
):
def
_find_compute_trace_from_node
(
self
,
node
):
"""
"""
...
@@ -193,7 +193,7 @@ class TraceIndice(object):
...
@@ -193,7 +193,7 @@ class TraceIndice(object):
compute (list): computed idx of the node.
compute (list): computed idx of the node.
"""
"""
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
return
self
.
i
dx
_trace_list
[
node_idx
][
"compute"
]
return
self
.
i
ndice
_trace_list
[
node_idx
][
"compute"
]
def
_assign_index_as_input
(
self
,
node
,
node_idx
,
input_node
=
None
):
def
_assign_index_as_input
(
self
,
node
,
node_idx
,
input_node
=
None
):
"""
"""
...
@@ -206,14 +206,14 @@ class TraceIndice(object):
...
@@ -206,14 +206,14 @@ class TraceIndice(object):
if
input_node
==
None
:
if
input_node
==
None
:
input_node
=
node
.
args
[
0
]
input_node
=
node
.
args
[
0
]
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx_trace
=
self
.
i
dx
_trace_list
[
input_node_idx
][
"idx"
]
input_node_idx_trace
=
self
.
i
ndice
_trace_list
[
input_node_idx
][
"idx"
]
new_idx_trace
=
copy
.
deepcopy
(
input_node_idx_trace
)
new_idx_trace
=
copy
.
deepcopy
(
input_node_idx_trace
)
self
.
i
dx
_trace_list
[
node_idx
][
"idx"
]
=
new_idx_trace
self
.
i
ndice
_trace_list
[
node_idx
][
"idx"
]
=
new_idx_trace
self
.
_inherit_all_computation
(
input_node
,
node
)
self
.
_inherit_all_computation
(
input_node
,
node
)
def
_assign_all_inde
x
(
self
,
node
,
node_idx
):
def
_assign_all_ind
ic
e
(
self
,
node
,
node_idx
):
"""
"""
Add new index for all node's dims.
Add new index for all node's dims.
...
@@ -224,10 +224,10 @@ class TraceIndice(object):
...
@@ -224,10 +224,10 @@ class TraceIndice(object):
shape
=
node
.
meta
[
"tensor_meta"
].
shape
shape
=
node
.
meta
[
"tensor_meta"
].
shape
new_trace
=
[]
new_trace
=
[]
for
_
in
shape
:
for
_
in
shape
:
new_trace
.
append
(
self
.
_add_inde
x
())
new_trace
.
append
(
self
.
_add_ind
ic
e
())
self
.
i
dx
_trace_list
[
node_idx
][
"idx"
]
=
new_trace
self
.
i
ndice
_trace_list
[
node_idx
][
"idx"
]
=
new_trace
def
_assign_transpose_inde
x
(
self
,
node
,
node_idx
):
def
_assign_transpose_ind
ic
e
(
self
,
node
,
node_idx
):
"""
"""
Assign index for transpose op.
Assign index for transpose op.
1. swap input's dim according to transpose args
1. swap input's dim according to transpose args
...
@@ -241,10 +241,10 @@ class TraceIndice(object):
...
@@ -241,10 +241,10 @@ class TraceIndice(object):
tranpose_dim
=
node
.
args
[
1
:]
tranpose_dim
=
node
.
args
[
1
:]
self
.
_assign_index_as_input
(
node
,
node_idx
,
input_node
)
self
.
_assign_index_as_input
(
node
,
node_idx
,
input_node
)
self
.
_inherit_inde
x
(
input_node
,
tranpose_dim
[
1
],
node
,
tranpose_dim
[
0
])
self
.
_inherit_ind
ic
e
(
input_node
,
tranpose_dim
[
1
],
node
,
tranpose_dim
[
0
])
self
.
_inherit_inde
x
(
input_node
,
tranpose_dim
[
0
],
node
,
tranpose_dim
[
1
])
self
.
_inherit_ind
ic
e
(
input_node
,
tranpose_dim
[
0
],
node
,
tranpose_dim
[
1
])
def
_assign_permute_inde
x
(
self
,
node
,
node_idx
):
def
_assign_permute_ind
ic
e
(
self
,
node
,
node_idx
):
"""
"""
Assign index for permute op.
Assign index for permute op.
1. swap input's dim according to permute args
1. swap input's dim according to permute args
...
@@ -259,9 +259,9 @@ class TraceIndice(object):
...
@@ -259,9 +259,9 @@ class TraceIndice(object):
self
.
_assign_index_as_input
(
node
,
node_idx
,
input_node
)
self
.
_assign_index_as_input
(
node
,
node_idx
,
input_node
)
for
idx
,
d
in
enumerate
(
permute_dim
):
for
idx
,
d
in
enumerate
(
permute_dim
):
self
.
_inherit_inde
x
(
input_node
,
d
,
node
,
idx
)
self
.
_inherit_ind
ic
e
(
input_node
,
d
,
node
,
idx
)
def
_assign_linear_inde
x
(
self
,
node
,
node_idx
):
def
_assign_linear_ind
ic
e
(
self
,
node
,
node_idx
):
"""
"""
Assign index for linear op.
Assign index for linear op.
1. copy trace from input node and change last index accroding to weight
1. copy trace from input node and change last index accroding to weight
...
@@ -279,15 +279,15 @@ class TraceIndice(object):
...
@@ -279,15 +279,15 @@ class TraceIndice(object):
input_node
,
weight
,
bias
=
node
.
args
input_node
,
weight
,
bias
=
node
.
args
self
.
_assign_index_as_input
(
node
,
node_idx
)
self
.
_assign_index_as_input
(
node
,
node_idx
)
self
.
_inherit_inde
x
(
weight
,
1
,
node
,
-
1
)
self
.
_inherit_ind
ic
e
(
weight
,
1
,
node
,
-
1
)
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
self
.
_mark_i
dx
_equal
(
input_node
,
-
1
,
weight
,
0
)
self
.
_mark_i
ndice
_equal
(
input_node
,
-
1
,
weight
,
0
)
if
bias
:
if
bias
:
self
.
_mark_i
dx
_equal
(
input_node
,
-
1
,
bias
,
0
)
self
.
_mark_i
ndice
_equal
(
input_node
,
-
1
,
bias
,
0
)
def
_assign_matmul_inde
x
(
self
,
node
,
node_idx
):
def
_assign_matmul_ind
ic
e
(
self
,
node
,
node_idx
):
"""
"""
Assign index for matmul op.
Assign index for matmul op.
1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length)
1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length)
...
@@ -302,13 +302,13 @@ class TraceIndice(object):
...
@@ -302,13 +302,13 @@ class TraceIndice(object):
assert
len
(
get_node_shape
(
matmul_left
))
==
len
(
get_node_shape
(
matmul_right
))
assert
len
(
get_node_shape
(
matmul_left
))
==
len
(
get_node_shape
(
matmul_right
))
self
.
_assign_index_as_input
(
node
,
node_idx
,
matmul_left
)
self
.
_assign_index_as_input
(
node
,
node_idx
,
matmul_left
)
self
.
_inherit_inde
x
(
matmul_right
,
-
1
,
node
,
-
1
)
self
.
_inherit_ind
ic
e
(
matmul_right
,
-
1
,
node
,
-
1
)
self
.
_mark_computation_from_node
(
matmul_right
,
node
,
[
-
1
,
-
2
])
self
.
_mark_computation_from_node
(
matmul_right
,
node
,
[
-
1
,
-
2
])
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
self
.
_mark_i
dx
_equal
(
matmul_left
,
-
1
,
matmul_right
,
-
2
)
self
.
_mark_i
ndice
_equal
(
matmul_left
,
-
1
,
matmul_right
,
-
2
)
def
_assign_layernorm_inde
x
(
self
,
node
,
idx
):
def
_assign_layernorm_ind
ic
e
(
self
,
node
,
idx
):
"""
"""
Assign index for layernorm op.
Assign index for layernorm op.
1. assign index as input node
1. assign index as input node
...
@@ -321,7 +321,7 @@ class TraceIndice(object):
...
@@ -321,7 +321,7 @@ class TraceIndice(object):
self
.
_assign_index_as_input
(
node
,
idx
)
self
.
_assign_index_as_input
(
node
,
idx
)
self
.
_mark_computation
(
node
,
idx
,
[
-
1
])
self
.
_mark_computation
(
node
,
idx
,
[
-
1
])
def
_assign_elementwise_inde
x
(
self
,
node
,
idx
):
def
_assign_elementwise_ind
ic
e
(
self
,
node
,
idx
):
"""
"""
Assign index for element-wise op (eg. relu sigmoid add mul).
Assign index for element-wise op (eg. relu sigmoid add mul).
1. assign index as input node
1. assign index as input node
...
@@ -343,15 +343,15 @@ class TraceIndice(object):
...
@@ -343,15 +343,15 @@ class TraceIndice(object):
node_in1_shape
=
get_node_shape
(
nodes_in
[
1
])
node_in1_shape
=
get_node_shape
(
nodes_in
[
1
])
for
i
in
range
(
-
1
,
-
min
(
len
(
node_in0_shape
),
len
(
node_in1_shape
))
-
1
,
-
1
):
for
i
in
range
(
-
1
,
-
min
(
len
(
node_in0_shape
),
len
(
node_in1_shape
))
-
1
,
-
1
):
if
node_in0_shape
[
i
]
==
node_in1_shape
[
i
]:
if
node_in0_shape
[
i
]
==
node_in1_shape
[
i
]:
self
.
_mark_i
dx
_equal
(
nodes_in
[
0
],
i
,
nodes_in
[
1
],
i
)
self
.
_mark_i
ndice
_equal
(
nodes_in
[
0
],
i
,
nodes_in
[
1
],
i
)
def
_assgin_no_change_inde
x
(
self
,
node
,
idx
):
def
_assgin_no_change_ind
ic
e
(
self
,
node
,
idx
):
self
.
_assign_index_as_input
(
node
,
idx
)
self
.
_assign_index_as_input
(
node
,
idx
)
for
node_in
in
node
.
args
:
for
node_in
in
node
.
args
:
if
type
(
node_in
)
==
type
(
node
):
if
type
(
node_in
)
==
type
(
node
):
self
.
_mark_computation_from_node
(
node_in
,
node
)
self
.
_mark_computation_from_node
(
node_in
,
node
)
def
_assign_einsum_inde
x
(
self
,
node
,
idx
):
def
_assign_einsum_ind
ic
e
(
self
,
node
,
idx
):
"""
"""
Assign index for einsum op.
Assign index for einsum op.
...
@@ -378,7 +378,7 @@ class TraceIndice(object):
...
@@ -378,7 +378,7 @@ class TraceIndice(object):
for
left_idx
,
left_str
in
enumerate
(
left
):
for
left_idx
,
left_str
in
enumerate
(
left
):
if
right_indice
in
left_str
:
if
right_indice
in
left_str
:
source_idx
=
left_str
.
index
(
right_indice
)
source_idx
=
left_str
.
index
(
right_indice
)
self
.
_inherit_inde
x
(
self
.
_inherit_ind
ic
e
(
input_nodes
[
left_idx
],
source_idx
,
node
,
right_idx
input_nodes
[
left_idx
],
source_idx
,
node
,
right_idx
)
)
...
@@ -388,7 +388,7 @@ class TraceIndice(object):
...
@@ -388,7 +388,7 @@ class TraceIndice(object):
# self._mark_computation(node, idx, left_str.index(i))
# self._mark_computation(node, idx, left_str.index(i))
# break
# break
def
_assign_softmax_inde
x
(
self
,
node
,
idx
):
def
_assign_softmax_ind
ic
e
(
self
,
node
,
idx
):
"""
"""
Assign index for softmax op.
Assign index for softmax op.
1. assign index as input node
1. assign index as input node
...
@@ -401,7 +401,7 @@ class TraceIndice(object):
...
@@ -401,7 +401,7 @@ class TraceIndice(object):
self
.
_assign_index_as_input
(
node
,
idx
)
self
.
_assign_index_as_input
(
node
,
idx
)
self
.
_mark_computation
(
node
,
idx
,
[
node
.
kwargs
[
"dim"
]])
self
.
_mark_computation
(
node
,
idx
,
[
node
.
kwargs
[
"dim"
]])
def
_assign_unsqueeze_inde
x
(
self
,
node
,
node_idx
):
def
_assign_unsqueeze_ind
ic
e
(
self
,
node
,
node_idx
):
"""
"""
Assign index for unsqueeze op.
Assign index for unsqueeze op.
1. assign new index for unsqueeze dim
1. assign new index for unsqueeze dim
...
@@ -414,7 +414,7 @@ class TraceIndice(object):
...
@@ -414,7 +414,7 @@ class TraceIndice(object):
self
.
_assign_index_as_input
(
node
,
node_idx
)
self
.
_assign_index_as_input
(
node
,
node_idx
)
self
.
_add_dim
(
node_idx
,
node
.
args
[
1
])
self
.
_add_dim
(
node_idx
,
node
.
args
[
1
])
def
_assign_dropout_inde
x
(
self
,
node
,
node_idx
):
def
_assign_dropout_ind
ic
e
(
self
,
node
,
node_idx
):
"""
"""
Assign index for unsqueeze op.
Assign index for unsqueeze op.
1. assign new index for unsqueeze dim
1. assign new index for unsqueeze dim
...
@@ -425,7 +425,7 @@ class TraceIndice(object):
...
@@ -425,7 +425,7 @@ class TraceIndice(object):
"""
"""
self
.
_assign_index_as_input
(
node
,
node_idx
)
self
.
_assign_index_as_input
(
node
,
node_idx
)
def
_assign_ones_like_inde
x
(
self
,
node
,
node_idx
):
def
_assign_ones_like_ind
ic
e
(
self
,
node
,
node_idx
):
"""
"""
Assign index for oneslike op.
Assign index for oneslike op.
1. assign new index for all dim
1. assign new index for all dim
...
@@ -434,9 +434,9 @@ class TraceIndice(object):
...
@@ -434,9 +434,9 @@ class TraceIndice(object):
node (node)
node (node)
node_idx (int)
node_idx (int)
"""
"""
self
.
_assign_all_inde
x
(
node
,
node_idx
)
self
.
_assign_all_ind
ic
e
(
node
,
node_idx
)
def
_assign_view_reshape_inde
x
(
self
,
node
,
node_idx
):
def
_assign_view_reshape_ind
ic
e
(
self
,
node
,
node_idx
):
"""
"""
Assign index for view and reshape op.
Assign index for view and reshape op.
1. get origin shape and target shape by meta info.
1. get origin shape and target shape by meta info.
...
@@ -496,7 +496,7 @@ class TraceIndice(object):
...
@@ -496,7 +496,7 @@ class TraceIndice(object):
)
)
# get new index
# get new index
origin_trace
=
self
.
_find_i
dx
_trace_from_node
(
origin_node
)
origin_trace
=
self
.
_find_i
ndice
_trace_from_node
(
origin_node
)
self
.
_assign_index_as_input
(
node
,
node_idx
,
origin_node
)
self
.
_assign_index_as_input
(
node
,
node_idx
,
origin_node
)
dim_from
.
reverse
()
dim_from
.
reverse
()
for
i
in
dim_from
:
for
i
in
dim_from
:
...
@@ -516,18 +516,18 @@ class TraceIndice(object):
...
@@ -516,18 +516,18 @@ class TraceIndice(object):
view_dict
=
{
view_dict
=
{
"idx_from"
:
[
origin_trace
[
i
]
for
i
in
dim_from
],
"idx_from"
:
[
origin_trace
[
i
]
for
i
in
dim_from
],
"dim_from"
:
dim_from
,
"dim_from"
:
dim_from
,
"idx_to"
:
[
self
.
i
dx
_trace_list
[
node_idx
][
"idx"
][
i
]
for
i
in
dim_to
],
"idx_to"
:
[
self
.
i
ndice
_trace_list
[
node_idx
][
"idx"
][
i
]
for
i
in
dim_to
],
"dim_to"
:
dim_to
,
"dim_to"
:
dim_to
,
}
}
self
.
i
dx
_view_list
[
node
]
=
view_dict
self
.
i
ndice
_view_list
[
node
]
=
view_dict
def
_merge_equal_idx
(
self
):
def
_merge_equal_idx
(
self
):
idx_equal
=
copy
.
deepcopy
(
self
.
i
dx
_trace_equal
)
idx_equal
=
copy
.
deepcopy
(
self
.
i
ndice
_trace_equal
)
idx_equal
.
reverse
()
idx_equal
.
reverse
()
for
idx
in
idx_equal
:
for
idx
in
idx_equal
:
merge_to
=
min
(
idx
)
merge_to
=
min
(
idx
)
merge_from
=
max
(
idx
)
merge_from
=
max
(
idx
)
for
trace
in
self
.
i
dx
_trace_list
:
for
trace
in
self
.
i
ndice
_trace_list
:
if
merge_from
in
trace
[
"idx"
]:
if
merge_from
in
trace
[
"idx"
]:
trace
[
"idx"
]
=
[
trace
[
"idx"
]
=
[
merge_to
if
i
==
merge_from
else
i
for
i
in
trace
[
"idx"
]
merge_to
if
i
==
merge_from
else
i
for
i
in
trace
[
"idx"
]
...
@@ -536,35 +536,35 @@ class TraceIndice(object):
...
@@ -536,35 +536,35 @@ class TraceIndice(object):
def
trace_index
(
self
):
def
trace_index
(
self
):
for
idx
,
node
in
enumerate
(
self
.
node_list
):
for
idx
,
node
in
enumerate
(
self
.
node_list
):
if
node
.
op
==
"placeholder"
:
if
node
.
op
==
"placeholder"
:
self
.
_assign_all_inde
x
(
node
,
idx
)
self
.
_assign_all_ind
ic
e
(
node
,
idx
)
elif
node
.
op
==
"call_method"
:
elif
node
.
op
==
"call_method"
:
if
"transpose"
in
node
.
name
:
if
"transpose"
in
node
.
name
:
self
.
_assign_transpose_inde
x
(
node
,
idx
)
self
.
_assign_transpose_ind
ic
e
(
node
,
idx
)
elif
"permute"
in
node
.
name
:
elif
"permute"
in
node
.
name
:
self
.
_assign_permute_inde
x
(
node
,
idx
)
self
.
_assign_permute_ind
ic
e
(
node
,
idx
)
elif
"view"
in
node
.
name
or
"reshape"
in
node
.
name
:
elif
"view"
in
node
.
name
or
"reshape"
in
node
.
name
:
self
.
_assign_view_reshape_inde
x
(
node
,
idx
)
self
.
_assign_view_reshape_ind
ic
e
(
node
,
idx
)
elif
"unsqueeze"
in
node
.
name
:
elif
"unsqueeze"
in
node
.
name
:
self
.
_assign_unsqueeze_inde
x
(
node
,
idx
)
self
.
_assign_unsqueeze_ind
ic
e
(
node
,
idx
)
elif
any
(
i
in
node
.
name
for
i
in
[
"to"
,
"contiguous"
]):
elif
any
(
i
in
node
.
name
for
i
in
[
"to"
,
"contiguous"
]):
self
.
_assgin_no_change_inde
x
(
node
,
idx
)
self
.
_assgin_no_change_ind
ic
e
(
node
,
idx
)
else
:
else
:
raise
NotImplementedError
(
node
.
name
,
"method not implemented yet!"
)
raise
NotImplementedError
(
node
.
name
,
"method not implemented yet!"
)
elif
node
.
op
==
"call_function"
:
elif
node
.
op
==
"call_function"
:
if
"linear"
in
node
.
name
:
if
"linear"
in
node
.
name
:
self
.
_assign_linear_inde
x
(
node
,
idx
)
self
.
_assign_linear_ind
ic
e
(
node
,
idx
)
elif
"matmul"
in
node
.
name
:
elif
"matmul"
in
node
.
name
:
self
.
_assign_matmul_inde
x
(
node
,
idx
)
self
.
_assign_matmul_ind
ic
e
(
node
,
idx
)
elif
"softmax"
in
node
.
name
:
elif
"softmax"
in
node
.
name
:
self
.
_assign_softmax_inde
x
(
node
,
idx
)
self
.
_assign_softmax_ind
ic
e
(
node
,
idx
)
elif
any
(
n
in
node
.
name
for
n
in
[
"mul"
,
"add"
,
"sigmoid"
,
"relu"
]):
elif
any
(
n
in
node
.
name
for
n
in
[
"mul"
,
"add"
,
"sigmoid"
,
"relu"
]):
self
.
_assign_elementwise_inde
x
(
node
,
idx
)
self
.
_assign_elementwise_ind
ic
e
(
node
,
idx
)
elif
"ones_like"
in
node
.
name
:
elif
"ones_like"
in
node
.
name
:
self
.
_assign_ones_like_inde
x
(
node
,
idx
)
self
.
_assign_ones_like_ind
ic
e
(
node
,
idx
)
elif
"dropout"
in
node
.
name
:
elif
"dropout"
in
node
.
name
:
self
.
_assign_dropout_inde
x
(
node
,
idx
)
self
.
_assign_dropout_ind
ic
e
(
node
,
idx
)
elif
"einsum"
in
node
.
name
:
elif
"einsum"
in
node
.
name
:
self
.
_assign_einsum_inde
x
(
node
,
idx
)
self
.
_assign_einsum_ind
ic
e
(
node
,
idx
)
elif
"getattr"
in
node
.
name
:
elif
"getattr"
in
node
.
name
:
continue
# get attr like shape
continue
# get attr like shape
elif
"getitem"
in
node
.
name
:
elif
"getitem"
in
node
.
name
:
...
@@ -575,11 +575,11 @@ class TraceIndice(object):
...
@@ -575,11 +575,11 @@ class TraceIndice(object):
)
)
elif
node
.
op
==
"call_module"
:
elif
node
.
op
==
"call_module"
:
if
any
(
n
in
node
.
name
for
n
in
[
"layernorm"
,
"norm"
]):
if
any
(
n
in
node
.
name
for
n
in
[
"layernorm"
,
"norm"
]):
self
.
_assign_layernorm_inde
x
(
node
,
idx
)
self
.
_assign_layernorm_ind
ic
e
(
node
,
idx
)
else
:
else
:
raise
NotImplementedError
(
node
.
name
,
"module not implemented yet!"
)
raise
NotImplementedError
(
node
.
name
,
"module not implemented yet!"
)
elif
node
.
op
==
"get_attr"
:
elif
node
.
op
==
"get_attr"
:
self
.
_assign_all_inde
x
(
node
,
idx
)
# get param
self
.
_assign_all_ind
ic
e
(
node
,
idx
)
# get param
elif
node
.
op
==
"output"
:
elif
node
.
op
==
"output"
:
continue
continue
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment