Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a4ed5b0d
Commit
a4ed5b0d
authored
Jan 09, 2023
by
oahzxl
Browse files
rename in doc
parent
1bb1f2ad
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
37 deletions
+29
-37
colossalai/autochunk/trace_indice.py
colossalai/autochunk/trace_indice.py
+29
-37
No files found.
colossalai/autochunk/trace_indice.py
View file @
a4ed5b0d
...
@@ -33,7 +33,7 @@ class TraceIndice(object):
...
@@ -33,7 +33,7 @@ class TraceIndice(object):
Update the count and return it. To record the idx number.
Update the count and return it. To record the idx number.
Returns:
Returns:
i
dx
_count: int
i
ndice
_count: int
"""
"""
self
.
indice_count
+=
1
self
.
indice_count
+=
1
return
self
.
indice_count
return
self
.
indice_count
...
@@ -113,11 +113,11 @@ class TraceIndice(object):
...
@@ -113,11 +113,11 @@ class TraceIndice(object):
def
_mark_indice_equal
(
self
,
node1
,
dim1
,
node2
,
dim2
):
def
_mark_indice_equal
(
self
,
node1
,
dim1
,
node2
,
dim2
):
"""
"""
Mark 2 inde
x
to be equal.
Mark 2 ind
ic
e to be equal.
Args:
Args:
idx1 (int): inde
x
count.
idx1 (int): ind
ic
e count.
idx2 (int): inde
x
count.
idx2 (int): ind
ic
e count.
"""
"""
# node1_idx = _find_idx_by_name(node1.name, self.nodes_list)
# node1_idx = _find_idx_by_name(node1.name, self.nodes_list)
# node2_idx = _find_idx_by_name(node2.name, self.nodes_list)
# node2_idx = _find_idx_by_name(node2.name, self.nodes_list)
...
@@ -215,7 +215,7 @@ class TraceIndice(object):
...
@@ -215,7 +215,7 @@ class TraceIndice(object):
def
_assign_all_indice
(
self
,
node
,
node_idx
):
def
_assign_all_indice
(
self
,
node
,
node_idx
):
"""
"""
Add new inde
x
for all node's dims.
Add new ind
ic
e for all node's dims.
Args:
Args:
node (node)
node (node)
...
@@ -229,7 +229,7 @@ class TraceIndice(object):
...
@@ -229,7 +229,7 @@ class TraceIndice(object):
def
_assign_transpose_indice
(
self
,
node
,
node_idx
):
def
_assign_transpose_indice
(
self
,
node
,
node_idx
):
"""
"""
Assign inde
x
for transpose op.
Assign ind
ic
e for transpose op.
1. swap input's dim according to transpose args
1. swap input's dim according to transpose args
2. inherit input's computation
2. inherit input's computation
...
@@ -246,7 +246,7 @@ class TraceIndice(object):
...
@@ -246,7 +246,7 @@ class TraceIndice(object):
def
_assign_permute_indice
(
self
,
node
,
node_idx
):
def
_assign_permute_indice
(
self
,
node
,
node_idx
):
"""
"""
Assign inde
x
for permute op.
Assign ind
ic
e for permute op.
1. swap input's dim according to permute args
1. swap input's dim according to permute args
2. inherit input's computation
2. inherit input's computation
...
@@ -263,9 +263,9 @@ class TraceIndice(object):
...
@@ -263,9 +263,9 @@ class TraceIndice(object):
def
_assign_linear_indice
(
self
,
node
,
node_idx
):
def
_assign_linear_indice
(
self
,
node
,
node_idx
):
"""
"""
Assign inde
x
for linear op.
Assign ind
ic
e for linear op.
1. copy trace from input node and change last inde
x
according to weight
1. copy trace from input node and change last ind
ic
e according to weight
2. mark equal for input node last inde
x
, weight first dim and bias dim.
2. mark equal for input node last ind
ic
e, weight first dim and bias dim.
3. inherit input's computation, mark computation for last dim.
3. inherit input's computation, mark computation for last dim.
Args:
Args:
...
@@ -289,9 +289,9 @@ class TraceIndice(object):
...
@@ -289,9 +289,9 @@ class TraceIndice(object):
def
_assign_matmul_indice
(
self
,
node
,
node_idx
):
def
_assign_matmul_indice
(
self
,
node
,
node_idx
):
"""
"""
Assign inde
x
for matmul op.
Assign ind
ic
e for matmul op.
1. copy trace from matmul_left and change last inde
x
according to matmul_right. (assert they have same length)
1. copy trace from matmul_left and change last ind
ic
e according to matmul_right. (assert they have same length)
2. mark equal for input matmul_left -1 inde
x
and matmul_right -2 dim.
2. mark equal for input matmul_left -1 ind
ic
e and matmul_right -2 dim.
3. inherit matmul_left and matmul_right computation, mark computation for last dim.
3. inherit matmul_left and matmul_right computation, mark computation for last dim.
Args:
Args:
...
@@ -310,8 +310,8 @@ class TraceIndice(object):
...
@@ -310,8 +310,8 @@ class TraceIndice(object):
def
_assign_layernorm_indice
(
self
,
node
,
idx
):
def
_assign_layernorm_indice
(
self
,
node
,
idx
):
"""
"""
Assign inde
x
for layernorm op.
Assign ind
ic
e for layernorm op.
1. assign inde
x
as input node
1. assign ind
ic
e as input node
2. inherit computation and mark last 2 dims as computed.
2. inherit computation and mark last 2 dims as computed.
Args:
Args:
...
@@ -323,8 +323,8 @@ class TraceIndice(object):
...
@@ -323,8 +323,8 @@ class TraceIndice(object):
def
_assign_elementwise_indice
(
self
,
node
,
idx
):
def
_assign_elementwise_indice
(
self
,
node
,
idx
):
"""
"""
Assign inde
x
for element-wise op (eg. relu sigmoid add mul).
Assign ind
ic
e for element-wise op (eg. relu sigmoid add mul).
1. assign inde
x
as input node
1. assign ind
ic
e as input node
2. inherit computation from all input nodes.
2. inherit computation from all input nodes.
Args:
Args:
...
@@ -353,7 +353,7 @@ class TraceIndice(object):
...
@@ -353,7 +353,7 @@ class TraceIndice(object):
def
_assign_einsum_indice
(
self
,
node
,
idx
):
def
_assign_einsum_indice
(
self
,
node
,
idx
):
"""
"""
Assign inde
x
for einsum op.
Assign ind
ic
e for einsum op.
Args:
Args:
node (node)
node (node)
...
@@ -371,8 +371,6 @@ class TraceIndice(object):
...
@@ -371,8 +371,6 @@ class TraceIndice(object):
for
c
in
i
:
for
c
in
i
:
all_index
.
append
(
c
)
all_index
.
append
(
c
)
all_index
=
set
(
all_index
)
all_index
=
set
(
all_index
)
free_index
=
set
([
i
for
i
in
right
])
sum_index
=
all_index
-
free_index
for
right_idx
,
right_indice
in
enumerate
(
right
):
for
right_idx
,
right_indice
in
enumerate
(
right
):
for
left_idx
,
left_str
in
enumerate
(
left
):
for
left_idx
,
left_str
in
enumerate
(
left
):
...
@@ -382,16 +380,10 @@ class TraceIndice(object):
...
@@ -382,16 +380,10 @@ class TraceIndice(object):
input_nodes
[
left_idx
],
source_idx
,
node
,
right_idx
input_nodes
[
left_idx
],
source_idx
,
node
,
right_idx
)
)
# for i in sum_index:
# for left_idx, left_str in enumerate(left):
# if i in left_str:
# self._mark_computation(node, idx, left_str.index(i))
# break
def
_assign_softmax_indice
(
self
,
node
,
idx
):
def
_assign_softmax_indice
(
self
,
node
,
idx
):
"""
"""
Assign inde
x
for softmax op.
Assign ind
ic
e for softmax op.
1. assign inde
x
as input node
1. assign ind
ic
e as input node
2. inherit computation and mark softmax dim as computed.
2. inherit computation and mark softmax dim as computed.
Args:
Args:
...
@@ -403,8 +395,8 @@ class TraceIndice(object):
...
@@ -403,8 +395,8 @@ class TraceIndice(object):
def
_assign_unsqueeze_indice
(
self
,
node
,
node_idx
):
def
_assign_unsqueeze_indice
(
self
,
node
,
node_idx
):
"""
"""
Assign inde
x
for unsqueeze op.
Assign ind
ic
e for unsqueeze op.
1. assign new inde
x
for unsqueeze dim
1. assign new ind
ic
e for unsqueeze dim
Args:
Args:
node (node)
node (node)
...
@@ -416,8 +408,8 @@ class TraceIndice(object):
...
@@ -416,8 +408,8 @@ class TraceIndice(object):
def
_assign_dropout_indice
(
self
,
node
,
node_idx
):
def
_assign_dropout_indice
(
self
,
node
,
node_idx
):
"""
"""
Assign inde
x
for dropout op.
Assign ind
ic
e for dropout op.
1. assign new inde
x
for unsqueeze dim
1. assign new ind
ic
e for unsqueeze dim
Args:
Args:
node (node)
node (node)
...
@@ -427,8 +419,8 @@ class TraceIndice(object):
...
@@ -427,8 +419,8 @@ class TraceIndice(object):
def
_assign_ones_like_indice
(
self
,
node
,
node_idx
):
def
_assign_ones_like_indice
(
self
,
node
,
node_idx
):
"""
"""
Assign inde
x
for oneslike op.
Assign ind
ic
e for oneslike op.
1. assign new inde
x
for all dim
1. assign new ind
ic
e for all dim
Args:
Args:
node (node)
node (node)
...
@@ -438,10 +430,10 @@ class TraceIndice(object):
...
@@ -438,10 +430,10 @@ class TraceIndice(object):
def
_assign_view_reshape_indice
(
self
,
node
,
node_idx
):
def
_assign_view_reshape_indice
(
self
,
node
,
node_idx
):
"""
"""
Assign inde
x
for view and reshape op.
Assign ind
ic
e for view and reshape op.
1. get origin shape and target shape by meta info.
1. get origin shape and target shape by meta info.
2. compute the real value of -1 in target shape.
2. compute the real value of -1 in target shape.
3. determine changed dim, and assign inde
x
for generated dim.
3. determine changed dim, and assign ind
ic
e for generated dim.
4. log changed dim and generated dim for restore
4. log changed dim and generated dim for restore
5. inherit computation.
5. inherit computation.
6. TODO: look into view list to see whether the view is associated with other,
6. TODO: look into view list to see whether the view is associated with other,
...
@@ -495,7 +487,7 @@ class TraceIndice(object):
...
@@ -495,7 +487,7 @@ class TraceIndice(object):
+
"view not implemented"
+
"view not implemented"
)
)
# get new inde
x
# get new ind
ic
e
origin_trace
=
self
.
_find_indice_trace_from_node
(
origin_node
)
origin_trace
=
self
.
_find_indice_trace_from_node
(
origin_node
)
self
.
_assign_indice_as_input
(
node
,
node_idx
,
origin_node
)
self
.
_assign_indice_as_input
(
node
,
node_idx
,
origin_node
)
dim_from
.
reverse
()
dim_from
.
reverse
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment