Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cb9817f7
Commit
cb9817f7
authored
Jan 09, 2023
by
oahzxl
Browse files
rename function from index to indice
parent
0ea903b9
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
91 additions
and
91 deletions
+91
-91
colossalai/autochunk/reorder_graph.py
colossalai/autochunk/reorder_graph.py
+6
-6
colossalai/autochunk/search_chunk.py
colossalai/autochunk/search_chunk.py
+1
-1
colossalai/autochunk/trace_flow.py
colossalai/autochunk/trace_flow.py
+1
-1
colossalai/autochunk/trace_indice.py
colossalai/autochunk/trace_indice.py
+83
-83
No files found.
colossalai/autochunk/reorder_graph.py
View file @
cb9817f7
...
...
@@ -6,7 +6,7 @@ class ReorderGraph(object):
def
__init__
(
self
,
trace_indice
:
TraceIndice
)
->
None
:
self
.
trace_indice
=
trace_indice
self
.
all_reorder_map
=
{
i
:
i
for
i
in
range
(
len
(
self
.
trace_indice
.
i
dx
_trace_list
))
i
:
i
for
i
in
range
(
len
(
self
.
trace_indice
.
i
ndice
_trace_list
))
}
def
_get_reorder_map
(
self
,
chunk_info
):
...
...
@@ -60,18 +60,18 @@ class ReorderGraph(object):
def
_reorder_idx_trace
(
self
,
reorder_map
):
# reorder list
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_indice
.
i
dx
_trace_list
))]
new_idx_trace_list
=
[
None
for
_
in
range
(
len
(
self
.
trace_indice
.
i
ndice
_trace_list
))]
for
old_idx
,
new_idx
in
reorder_map
.
items
():
new_idx_trace_list
[
new_idx
]
=
self
.
trace_indice
.
i
dx
_trace_list
[
old_idx
]
self
.
trace_indice
.
i
dx
_trace_list
=
new_idx_trace_list
new_idx_trace_list
[
new_idx
]
=
self
.
trace_indice
.
i
ndice
_trace_list
[
old_idx
]
self
.
trace_indice
.
i
ndice
_trace_list
=
new_idx_trace_list
# update compute
for
idx_trace
in
self
.
trace_indice
.
i
dx
_trace_list
:
for
idx_trace
in
self
.
trace_indice
.
i
ndice
_trace_list
:
compute
=
idx_trace
[
"compute"
]
for
dim_compute
in
compute
:
for
idx
,
i
in
enumerate
(
dim_compute
):
dim_compute
[
idx
]
=
reorder_map
[
i
]
# update source
for
idx_trace
in
self
.
trace_indice
.
i
dx
_trace_list
:
for
idx_trace
in
self
.
trace_indice
.
i
ndice
_trace_list
:
source
=
idx_trace
[
"source"
]
for
dim_idx
,
dim_source
in
enumerate
(
source
):
new_dim_source
=
{}
...
...
colossalai/autochunk/search_chunk.py
View file @
cb9817f7
...
...
@@ -205,7 +205,7 @@ class SearchChunk(object):
possible_chunk_region (List)
"""
possible_chunk_region
=
[]
output_trace
=
copy
.
deepcopy
(
self
.
trace_indice
.
i
dx
_trace_list
)
output_trace
=
copy
.
deepcopy
(
self
.
trace_indice
.
i
ndice
_trace_list
)
input_trace
=
[]
# trace of a node's input nodes
for
_
,
n
in
enumerate
(
self
.
trace_indice
.
node_list
):
cur_trace
=
{}
...
...
colossalai/autochunk/trace_flow.py
View file @
cb9817f7
...
...
@@ -406,7 +406,7 @@ class TraceFlow(object):
for
node
in
self
.
trace_indice
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
if
any
(
i
in
node
.
name
for
i
in
[
"reshape"
,
"view"
]):
reshape_args
=
node
.
args
[
1
:]
reshape_log
=
self
.
trace_indice
.
i
dx
_view_list
[
node
]
reshape_log
=
self
.
trace_indice
.
i
ndice
_view_list
[
node
]
chunk_dim
=
chunk_info
[
"node_chunk_dim"
][
node
][
"chunk_dim"
]
reshape_size
[
node
.
name
]
=
{}
for
reshape_arg_dim
,
reshape_arg
in
enumerate
(
reshape_args
):
...
...
colossalai/autochunk/trace_indice.py
View file @
cb9817f7
...
...
@@ -9,13 +9,13 @@ from .utils import (
class
TraceIndice
(
object
):
def
__init__
(
self
,
node_list
)
->
None
:
self
.
node_list
=
node_list
self
.
i
dx
_trace_list
=
self
.
_init_i
dx
_trace_list
()
self
.
i
dx
_trace_equal
=
[]
self
.
i
dx
_view_list
=
{}
self
.
i
dx
_count
=
-
1
self
.
i
ndice
_trace_list
=
self
.
_init_i
ndice
_trace_list
()
self
.
i
ndice
_trace_equal
=
[]
self
.
i
ndice
_view_list
=
{}
self
.
i
ndice
_count
=
-
1
def
_init_i
dx
_trace_list
(
self
):
i
dx
_trace_list
=
[]
def
_init_i
ndice
_trace_list
(
self
):
i
ndice
_trace_list
=
[]
for
n
in
self
.
node_list
:
if
get_node_shape
(
n
)
!=
None
:
cur_trace
=
{
...
...
@@ -25,37 +25,37 @@ class TraceIndice(object):
}
else
:
cur_trace
=
{
"idx"
:
[],
"compute"
:
[],
"source"
:
[]}
i
dx
_trace_list
.
append
(
cur_trace
)
return
i
dx
_trace_list
i
ndice
_trace_list
.
append
(
cur_trace
)
return
i
ndice
_trace_list
def
_add_inde
x
(
self
):
def
_add_ind
ic
e
(
self
):
"""
Update the count and return it. To record the idx number.
Returns:
idx_count: int
"""
self
.
i
dx
_count
+=
1
return
self
.
i
dx
_count
self
.
i
ndice
_count
+=
1
return
self
.
i
ndice
_count
def
_del_dim
(
self
,
idx
,
dim_idx
):
self
.
i
dx
_trace_list
[
idx
][
"idx"
].
pop
(
dim_idx
)
self
.
i
dx
_trace_list
[
idx
][
"compute"
].
pop
(
dim_idx
)
self
.
i
dx
_trace_list
[
idx
][
"source"
].
pop
(
dim_idx
)
self
.
i
ndice
_trace_list
[
idx
][
"idx"
].
pop
(
dim_idx
)
self
.
i
ndice
_trace_list
[
idx
][
"compute"
].
pop
(
dim_idx
)
self
.
i
ndice
_trace_list
[
idx
][
"source"
].
pop
(
dim_idx
)
def
_add_dim
(
self
,
node_idx
,
dim_idx
):
self
.
i
dx
_trace_list
[
node_idx
][
"idx"
].
insert
(
dim_idx
,
self
.
_add_inde
x
())
self
.
i
dx
_trace_list
[
node_idx
][
"compute"
].
insert
(
dim_idx
,
[])
self
.
i
dx
_trace_list
[
node_idx
][
"source"
].
insert
(
dim_idx
,
{})
self
.
i
ndice
_trace_list
[
node_idx
][
"idx"
].
insert
(
dim_idx
,
self
.
_add_ind
ic
e
())
self
.
i
ndice
_trace_list
[
node_idx
][
"compute"
].
insert
(
dim_idx
,
[])
self
.
i
ndice
_trace_list
[
node_idx
][
"source"
].
insert
(
dim_idx
,
{})
def
_transform_inde
x
(
self
,
node
,
node_dim
):
node_idx
=
self
.
_find_i
dx
_trace_from_node
(
node
)
def
_transform_ind
ic
e
(
self
,
node
,
node_dim
):
node_idx
=
self
.
_find_i
ndice
_trace_from_node
(
node
)
dims
=
list
(
range
(
len
(
node_idx
)))
return
dims
[
node_dim
]
def
_inherit_inde
x
(
self
,
node_from
,
node_from_dim
,
node_to
,
node_to_dim
):
node_from_dim
=
self
.
_transform_inde
x
(
node_from
,
node_from_dim
)
node_to_dim
=
self
.
_transform_inde
x
(
node_to
,
node_to_dim
)
def
_inherit_ind
ic
e
(
self
,
node_from
,
node_from_dim
,
node_to
,
node_to_dim
):
node_from_dim
=
self
.
_transform_ind
ic
e
(
node_from
,
node_from_dim
)
node_to_dim
=
self
.
_transform_ind
ic
e
(
node_to
,
node_to_dim
)
node_from_trace
=
self
.
_find_trace_from_node
(
node_from
)
node_to_trace
=
self
.
_find_trace_from_node
(
node_to
)
node_to_trace
[
"idx"
][
node_to_dim
]
=
node_from_trace
[
"idx"
][
node_from_dim
]
...
...
@@ -73,9 +73,9 @@ class TraceIndice(object):
node_to_compute
[
i
]
=
copy
.
deepcopy
(
node_from_compute
[
i
])
def
_add_source
(
self
,
node_from
,
node_from_dim
,
node_to
,
node_to_dim
,
init
=
False
):
node_from_dim
=
self
.
_transform_inde
x
(
node_from
,
node_from_dim
)
node_from_dim
=
self
.
_transform_ind
ic
e
(
node_from
,
node_from_dim
)
node_from_trace_source
=
self
.
_find_source_trace_from_node
(
node_from
)
node_to_dim
=
self
.
_transform_inde
x
(
node_to
,
node_to_dim
)
node_to_dim
=
self
.
_transform_ind
ic
e
(
node_to
,
node_to_dim
)
node_to_trace_source
=
self
.
_find_source_trace_from_node
(
node_to
)
node_from_idx
=
find_idx_by_name
(
node_from
.
name
,
self
.
node_list
)
if
init
:
...
...
@@ -99,19 +99,19 @@ class TraceIndice(object):
if
exclude
==
None
:
exclude
=
[]
else
:
exclude
=
[
self
.
_transform_inde
x
(
node_to
,
i
)
for
i
in
exclude
]
exclude
=
[
self
.
_transform_ind
ic
e
(
node_to
,
i
)
for
i
in
exclude
]
node_from_compute
=
self
.
_find_compute_trace_from_node
(
node_from
)
node_to_compute
=
self
.
_find_compute_trace_from_node
(
node_to
)
# assert len(node_from_compute) == len(node_to_compute)
for
i
in
range
(
-
1
,
-
min
(
len
(
node_from_compute
),
len
(
node_to_compute
))
-
1
,
-
1
):
if
self
.
_transform_inde
x
(
node_to
,
i
)
in
exclude
:
if
self
.
_transform_ind
ic
e
(
node_to
,
i
)
in
exclude
:
continue
self
.
_add_source
(
node_from
,
i
,
node_to
,
i
)
for
j
in
node_from_compute
[
i
]:
if
j
not
in
node_to_compute
[
i
]:
node_to_compute
[
i
].
append
(
j
)
def
_mark_i
dx
_equal
(
self
,
node1
,
dim1
,
node2
,
dim2
):
def
_mark_i
ndice
_equal
(
self
,
node1
,
dim1
,
node2
,
dim2
):
"""
Mark 2 index to be equal.
...
...
@@ -140,8 +140,8 @@ class TraceIndice(object):
dims
=
list
(
range
(
len
(
get_node_shape
(
node
))))
for
d
in
dim
:
cur_dim
=
dims
[
d
]
if
idx
not
in
self
.
i
dx
_trace_list
[
idx
][
"compute"
][
cur_dim
]:
self
.
i
dx
_trace_list
[
idx
][
"compute"
][
cur_dim
].
append
(
idx
)
if
idx
not
in
self
.
i
ndice
_trace_list
[
idx
][
"compute"
][
cur_dim
]:
self
.
i
ndice
_trace_list
[
idx
][
"compute"
][
cur_dim
].
append
(
idx
)
def
_find_trace_from_node
(
self
,
node
):
"""
...
...
@@ -154,7 +154,7 @@ class TraceIndice(object):
compute (list): computed idx of the node.
"""
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_dict
=
self
.
i
dx
_trace_list
[
node_idx
]
node_dict
=
self
.
i
ndice
_trace_list
[
node_idx
]
return
node_dict
def
_find_source_trace_from_node
(
self
,
node
):
...
...
@@ -168,10 +168,10 @@ class TraceIndice(object):
compute (list): computed idx of the node.
"""
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
node_dict
=
self
.
i
dx
_trace_list
[
node_idx
]
node_dict
=
self
.
i
ndice
_trace_list
[
node_idx
]
return
node_dict
[
"source"
]
def
_find_i
dx
_trace_from_node
(
self
,
node
):
def
_find_i
ndice
_trace_from_node
(
self
,
node
):
"""
Find node idx trace by the node.
...
...
@@ -181,7 +181,7 @@ class TraceIndice(object):
idx (list): idx of the node
"""
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
return
self
.
i
dx
_trace_list
[
node_idx
][
"idx"
]
return
self
.
i
ndice
_trace_list
[
node_idx
][
"idx"
]
def
_find_compute_trace_from_node
(
self
,
node
):
"""
...
...
@@ -193,7 +193,7 @@ class TraceIndice(object):
compute (list): computed idx of the node.
"""
node_idx
=
find_idx_by_name
(
node
.
name
,
self
.
node_list
)
return
self
.
i
dx
_trace_list
[
node_idx
][
"compute"
]
return
self
.
i
ndice
_trace_list
[
node_idx
][
"compute"
]
def
_assign_index_as_input
(
self
,
node
,
node_idx
,
input_node
=
None
):
"""
...
...
@@ -206,14 +206,14 @@ class TraceIndice(object):
if
input_node
==
None
:
input_node
=
node
.
args
[
0
]
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx_trace
=
self
.
i
dx
_trace_list
[
input_node_idx
][
"idx"
]
input_node_idx_trace
=
self
.
i
ndice
_trace_list
[
input_node_idx
][
"idx"
]
new_idx_trace
=
copy
.
deepcopy
(
input_node_idx_trace
)
self
.
i
dx
_trace_list
[
node_idx
][
"idx"
]
=
new_idx_trace
self
.
i
ndice
_trace_list
[
node_idx
][
"idx"
]
=
new_idx_trace
self
.
_inherit_all_computation
(
input_node
,
node
)
def
_assign_all_inde
x
(
self
,
node
,
node_idx
):
def
_assign_all_ind
ic
e
(
self
,
node
,
node_idx
):
"""
Add new index for all node's dims.
...
...
@@ -224,10 +224,10 @@ class TraceIndice(object):
shape
=
node
.
meta
[
"tensor_meta"
].
shape
new_trace
=
[]
for
_
in
shape
:
new_trace
.
append
(
self
.
_add_inde
x
())
self
.
i
dx
_trace_list
[
node_idx
][
"idx"
]
=
new_trace
new_trace
.
append
(
self
.
_add_ind
ic
e
())
self
.
i
ndice
_trace_list
[
node_idx
][
"idx"
]
=
new_trace
def
_assign_transpose_inde
x
(
self
,
node
,
node_idx
):
def
_assign_transpose_ind
ic
e
(
self
,
node
,
node_idx
):
"""
Assign index for transpose op.
1. swap input's dim according to transpose args
...
...
@@ -241,10 +241,10 @@ class TraceIndice(object):
tranpose_dim
=
node
.
args
[
1
:]
self
.
_assign_index_as_input
(
node
,
node_idx
,
input_node
)
self
.
_inherit_inde
x
(
input_node
,
tranpose_dim
[
1
],
node
,
tranpose_dim
[
0
])
self
.
_inherit_inde
x
(
input_node
,
tranpose_dim
[
0
],
node
,
tranpose_dim
[
1
])
self
.
_inherit_ind
ic
e
(
input_node
,
tranpose_dim
[
1
],
node
,
tranpose_dim
[
0
])
self
.
_inherit_ind
ic
e
(
input_node
,
tranpose_dim
[
0
],
node
,
tranpose_dim
[
1
])
def
_assign_permute_inde
x
(
self
,
node
,
node_idx
):
def
_assign_permute_ind
ic
e
(
self
,
node
,
node_idx
):
"""
Assign index for permute op.
1. swap input's dim according to permute args
...
...
@@ -259,9 +259,9 @@ class TraceIndice(object):
self
.
_assign_index_as_input
(
node
,
node_idx
,
input_node
)
for
idx
,
d
in
enumerate
(
permute_dim
):
self
.
_inherit_inde
x
(
input_node
,
d
,
node
,
idx
)
self
.
_inherit_ind
ic
e
(
input_node
,
d
,
node
,
idx
)
def
_assign_linear_inde
x
(
self
,
node
,
node_idx
):
def
_assign_linear_ind
ic
e
(
self
,
node
,
node_idx
):
"""
Assign index for linear op.
1. copy trace from input node and change last index accroding to weight
...
...
@@ -279,15 +279,15 @@ class TraceIndice(object):
input_node
,
weight
,
bias
=
node
.
args
self
.
_assign_index_as_input
(
node
,
node_idx
)
self
.
_inherit_inde
x
(
weight
,
1
,
node
,
-
1
)
self
.
_inherit_ind
ic
e
(
weight
,
1
,
node
,
-
1
)
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
self
.
_mark_i
dx
_equal
(
input_node
,
-
1
,
weight
,
0
)
self
.
_mark_i
ndice
_equal
(
input_node
,
-
1
,
weight
,
0
)
if
bias
:
self
.
_mark_i
dx
_equal
(
input_node
,
-
1
,
bias
,
0
)
self
.
_mark_i
ndice
_equal
(
input_node
,
-
1
,
bias
,
0
)
def
_assign_matmul_inde
x
(
self
,
node
,
node_idx
):
def
_assign_matmul_ind
ic
e
(
self
,
node
,
node_idx
):
"""
Assign index for matmul op.
1. copy trace from matmul_left and change last index accroding to matmul_right. (assert they have same length)
...
...
@@ -302,13 +302,13 @@ class TraceIndice(object):
assert
len
(
get_node_shape
(
matmul_left
))
==
len
(
get_node_shape
(
matmul_right
))
self
.
_assign_index_as_input
(
node
,
node_idx
,
matmul_left
)
self
.
_inherit_inde
x
(
matmul_right
,
-
1
,
node
,
-
1
)
self
.
_inherit_ind
ic
e
(
matmul_right
,
-
1
,
node
,
-
1
)
self
.
_mark_computation_from_node
(
matmul_right
,
node
,
[
-
1
,
-
2
])
self
.
_mark_computation
(
node
,
node_idx
,
[
-
1
])
self
.
_mark_i
dx
_equal
(
matmul_left
,
-
1
,
matmul_right
,
-
2
)
self
.
_mark_i
ndice
_equal
(
matmul_left
,
-
1
,
matmul_right
,
-
2
)
def
_assign_layernorm_inde
x
(
self
,
node
,
idx
):
def
_assign_layernorm_ind
ic
e
(
self
,
node
,
idx
):
"""
Assign index for layernorm op.
1. assign index as input node
...
...
@@ -321,7 +321,7 @@ class TraceIndice(object):
self
.
_assign_index_as_input
(
node
,
idx
)
self
.
_mark_computation
(
node
,
idx
,
[
-
1
])
def
_assign_elementwise_inde
x
(
self
,
node
,
idx
):
def
_assign_elementwise_ind
ic
e
(
self
,
node
,
idx
):
"""
Assign index for element-wise op (eg. relu sigmoid add mul).
1. assign index as input node
...
...
@@ -343,15 +343,15 @@ class TraceIndice(object):
node_in1_shape
=
get_node_shape
(
nodes_in
[
1
])
for
i
in
range
(
-
1
,
-
min
(
len
(
node_in0_shape
),
len
(
node_in1_shape
))
-
1
,
-
1
):
if
node_in0_shape
[
i
]
==
node_in1_shape
[
i
]:
self
.
_mark_i
dx
_equal
(
nodes_in
[
0
],
i
,
nodes_in
[
1
],
i
)
self
.
_mark_i
ndice
_equal
(
nodes_in
[
0
],
i
,
nodes_in
[
1
],
i
)
def
_assgin_no_change_inde
x
(
self
,
node
,
idx
):
def
_assgin_no_change_ind
ic
e
(
self
,
node
,
idx
):
self
.
_assign_index_as_input
(
node
,
idx
)
for
node_in
in
node
.
args
:
if
type
(
node_in
)
==
type
(
node
):
self
.
_mark_computation_from_node
(
node_in
,
node
)
def
_assign_einsum_inde
x
(
self
,
node
,
idx
):
def
_assign_einsum_ind
ic
e
(
self
,
node
,
idx
):
"""
Assign index for einsum op.
...
...
@@ -378,7 +378,7 @@ class TraceIndice(object):
for
left_idx
,
left_str
in
enumerate
(
left
):
if
right_indice
in
left_str
:
source_idx
=
left_str
.
index
(
right_indice
)
self
.
_inherit_inde
x
(
self
.
_inherit_ind
ic
e
(
input_nodes
[
left_idx
],
source_idx
,
node
,
right_idx
)
...
...
@@ -388,7 +388,7 @@ class TraceIndice(object):
# self._mark_computation(node, idx, left_str.index(i))
# break
def
_assign_softmax_inde
x
(
self
,
node
,
idx
):
def
_assign_softmax_ind
ic
e
(
self
,
node
,
idx
):
"""
Assign index for softmax op.
1. assign index as input node
...
...
@@ -401,7 +401,7 @@ class TraceIndice(object):
self
.
_assign_index_as_input
(
node
,
idx
)
self
.
_mark_computation
(
node
,
idx
,
[
node
.
kwargs
[
"dim"
]])
def
_assign_unsqueeze_inde
x
(
self
,
node
,
node_idx
):
def
_assign_unsqueeze_ind
ic
e
(
self
,
node
,
node_idx
):
"""
Assign index for unsqueeze op.
1. assign new index for unsqueeze dim
...
...
@@ -414,7 +414,7 @@ class TraceIndice(object):
self
.
_assign_index_as_input
(
node
,
node_idx
)
self
.
_add_dim
(
node_idx
,
node
.
args
[
1
])
def
_assign_dropout_inde
x
(
self
,
node
,
node_idx
):
def
_assign_dropout_ind
ic
e
(
self
,
node
,
node_idx
):
"""
Assign index for unsqueeze op.
1. assign new index for unsqueeze dim
...
...
@@ -425,7 +425,7 @@ class TraceIndice(object):
"""
self
.
_assign_index_as_input
(
node
,
node_idx
)
def
_assign_ones_like_inde
x
(
self
,
node
,
node_idx
):
def
_assign_ones_like_ind
ic
e
(
self
,
node
,
node_idx
):
"""
Assign index for oneslike op.
1. assign new index for all dim
...
...
@@ -434,9 +434,9 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
self
.
_assign_all_inde
x
(
node
,
node_idx
)
self
.
_assign_all_ind
ic
e
(
node
,
node_idx
)
def
_assign_view_reshape_inde
x
(
self
,
node
,
node_idx
):
def
_assign_view_reshape_ind
ic
e
(
self
,
node
,
node_idx
):
"""
Assign index for view and reshape op.
1. get origin shape and target shape by meta info.
...
...
@@ -496,7 +496,7 @@ class TraceIndice(object):
)
# get new index
origin_trace
=
self
.
_find_i
dx
_trace_from_node
(
origin_node
)
origin_trace
=
self
.
_find_i
ndice
_trace_from_node
(
origin_node
)
self
.
_assign_index_as_input
(
node
,
node_idx
,
origin_node
)
dim_from
.
reverse
()
for
i
in
dim_from
:
...
...
@@ -516,18 +516,18 @@ class TraceIndice(object):
view_dict
=
{
"idx_from"
:
[
origin_trace
[
i
]
for
i
in
dim_from
],
"dim_from"
:
dim_from
,
"idx_to"
:
[
self
.
i
dx
_trace_list
[
node_idx
][
"idx"
][
i
]
for
i
in
dim_to
],
"idx_to"
:
[
self
.
i
ndice
_trace_list
[
node_idx
][
"idx"
][
i
]
for
i
in
dim_to
],
"dim_to"
:
dim_to
,
}
self
.
i
dx
_view_list
[
node
]
=
view_dict
self
.
i
ndice
_view_list
[
node
]
=
view_dict
def
_merge_equal_idx
(
self
):
idx_equal
=
copy
.
deepcopy
(
self
.
i
dx
_trace_equal
)
idx_equal
=
copy
.
deepcopy
(
self
.
i
ndice
_trace_equal
)
idx_equal
.
reverse
()
for
idx
in
idx_equal
:
merge_to
=
min
(
idx
)
merge_from
=
max
(
idx
)
for
trace
in
self
.
i
dx
_trace_list
:
for
trace
in
self
.
i
ndice
_trace_list
:
if
merge_from
in
trace
[
"idx"
]:
trace
[
"idx"
]
=
[
merge_to
if
i
==
merge_from
else
i
for
i
in
trace
[
"idx"
]
...
...
@@ -536,35 +536,35 @@ class TraceIndice(object):
def
trace_index
(
self
):
for
idx
,
node
in
enumerate
(
self
.
node_list
):
if
node
.
op
==
"placeholder"
:
self
.
_assign_all_inde
x
(
node
,
idx
)
self
.
_assign_all_ind
ic
e
(
node
,
idx
)
elif
node
.
op
==
"call_method"
:
if
"transpose"
in
node
.
name
:
self
.
_assign_transpose_inde
x
(
node
,
idx
)
self
.
_assign_transpose_ind
ic
e
(
node
,
idx
)
elif
"permute"
in
node
.
name
:
self
.
_assign_permute_inde
x
(
node
,
idx
)
self
.
_assign_permute_ind
ic
e
(
node
,
idx
)
elif
"view"
in
node
.
name
or
"reshape"
in
node
.
name
:
self
.
_assign_view_reshape_inde
x
(
node
,
idx
)
self
.
_assign_view_reshape_ind
ic
e
(
node
,
idx
)
elif
"unsqueeze"
in
node
.
name
:
self
.
_assign_unsqueeze_inde
x
(
node
,
idx
)
self
.
_assign_unsqueeze_ind
ic
e
(
node
,
idx
)
elif
any
(
i
in
node
.
name
for
i
in
[
"to"
,
"contiguous"
]):
self
.
_assgin_no_change_inde
x
(
node
,
idx
)
self
.
_assgin_no_change_ind
ic
e
(
node
,
idx
)
else
:
raise
NotImplementedError
(
node
.
name
,
"method not implemented yet!"
)
elif
node
.
op
==
"call_function"
:
if
"linear"
in
node
.
name
:
self
.
_assign_linear_inde
x
(
node
,
idx
)
self
.
_assign_linear_ind
ic
e
(
node
,
idx
)
elif
"matmul"
in
node
.
name
:
self
.
_assign_matmul_inde
x
(
node
,
idx
)
self
.
_assign_matmul_ind
ic
e
(
node
,
idx
)
elif
"softmax"
in
node
.
name
:
self
.
_assign_softmax_inde
x
(
node
,
idx
)
self
.
_assign_softmax_ind
ic
e
(
node
,
idx
)
elif
any
(
n
in
node
.
name
for
n
in
[
"mul"
,
"add"
,
"sigmoid"
,
"relu"
]):
self
.
_assign_elementwise_inde
x
(
node
,
idx
)
self
.
_assign_elementwise_ind
ic
e
(
node
,
idx
)
elif
"ones_like"
in
node
.
name
:
self
.
_assign_ones_like_inde
x
(
node
,
idx
)
self
.
_assign_ones_like_ind
ic
e
(
node
,
idx
)
elif
"dropout"
in
node
.
name
:
self
.
_assign_dropout_inde
x
(
node
,
idx
)
self
.
_assign_dropout_ind
ic
e
(
node
,
idx
)
elif
"einsum"
in
node
.
name
:
self
.
_assign_einsum_inde
x
(
node
,
idx
)
self
.
_assign_einsum_ind
ic
e
(
node
,
idx
)
elif
"getattr"
in
node
.
name
:
continue
# get attr like shape
elif
"getitem"
in
node
.
name
:
...
...
@@ -575,11 +575,11 @@ class TraceIndice(object):
)
elif
node
.
op
==
"call_module"
:
if
any
(
n
in
node
.
name
for
n
in
[
"layernorm"
,
"norm"
]):
self
.
_assign_layernorm_inde
x
(
node
,
idx
)
self
.
_assign_layernorm_ind
ic
e
(
node
,
idx
)
else
:
raise
NotImplementedError
(
node
.
name
,
"module not implemented yet!"
)
elif
node
.
op
==
"get_attr"
:
self
.
_assign_all_inde
x
(
node
,
idx
)
# get param
self
.
_assign_all_ind
ic
e
(
node
,
idx
)
# get param
elif
node
.
op
==
"output"
:
continue
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment