Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
70a98b8f
Commit
70a98b8f
authored
Nov 14, 2022
by
oahzxl
Browse files
add doc string
parent
c36dba07
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
2 deletions
+57
-2
chunk_codegen.py
chunk_codegen.py
+57
-2
No files found.
chunk_codegen.py
View file @
70a98b8f
...
@@ -26,13 +26,28 @@ class NodeIndexTracer(object):
...
@@ -26,13 +26,28 @@ class NodeIndexTracer(object):
self
.
idx_trace_list
=
[{
'idx'
:
[],
'compute'
:
[]}
for
_
in
range
(
len
(
self
.
nodes_list
))]
self
.
idx_trace_list
=
[{
'idx'
:
[],
'compute'
:
[]}
for
_
in
range
(
len
(
self
.
nodes_list
))]
self
.
idx_trace_equal
=
[]
self
.
idx_trace_equal
=
[]
self
.
idx_view_list
=
[]
self
.
idx_view_list
=
[]
self
.
idx_count
=
1
self
.
idx_count
=
-
1
def
add_index
(
self
):
def
add_index
(
self
):
"""
Update the count and return it. To record the idx number.
Returns:
idx_count: int
"""
self
.
idx_count
+=
1
self
.
idx_count
+=
1
return
self
.
idx_count
-
1
return
self
.
idx_count
def
inherit_computation
(
self
,
node_from
,
node_to
):
def
inherit_computation
(
self
,
node_from
,
node_to
):
"""
Inherit computed dim from node_from to node_to.
If a dim in node_from is marked as computed and exists in node_to,
still mark it as computed in node_to.
Args:
node_from (node): node to be inherited
node_to (node): new node to inherit
"""
_
,
compute_from
=
self
.
find_trace_from_node
(
node_from
)
_
,
compute_from
=
self
.
find_trace_from_node
(
node_from
)
idx_to
,
compute_to
=
self
.
find_trace_from_node
(
node_to
)
idx_to
,
compute_to
=
self
.
find_trace_from_node
(
node_to
)
for
i
in
compute_from
:
for
i
in
compute_from
:
...
@@ -40,9 +55,24 @@ class NodeIndexTracer(object):
...
@@ -40,9 +55,24 @@ class NodeIndexTracer(object):
compute_to
.
append
(
i
)
compute_to
.
append
(
i
)
def
mark_idx_equal
(
self
,
idx1
,
idx2
):
def
mark_idx_equal
(
self
,
idx1
,
idx2
):
"""
Mark 2 index to be equal.
Args:
idx1 (int): index count.
idx2 (int): index count.
"""
self
.
idx_trace_equal
.
append
((
idx1
,
idx2
))
self
.
idx_trace_equal
.
append
((
idx1
,
idx2
))
def
mark_computation
(
self
,
node
,
idx
,
dim
):
def
mark_computation
(
self
,
node
,
idx
,
dim
):
"""
Mark some dims of node as computed.
Args:
node (node)
idx (int): node index
dim (list or int): dims to be marked as computed
"""
input_node_idx_trace
=
self
.
find_idx_trace_from_node
(
node
)
input_node_idx_trace
=
self
.
find_idx_trace_from_node
(
node
)
if
isinstance
(
dim
,
int
):
if
isinstance
(
dim
,
int
):
dim
=
[
dim
]
dim
=
[
dim
]
...
@@ -52,15 +82,40 @@ class NodeIndexTracer(object):
...
@@ -52,15 +82,40 @@ class NodeIndexTracer(object):
self
.
idx_trace_list
[
idx
][
'compute'
].
append
(
cur_idx
)
self
.
idx_trace_list
[
idx
][
'compute'
].
append
(
cur_idx
)
def
find_trace_from_node
(
self
,
node
):
def
find_trace_from_node
(
self
,
node
):
"""
Find node idx and compute trace by the node.
Args:
node (node)
Returns:
idx (list): idx of the node
compute (list): computed idx of the node.
"""
node_idx
=
_find_idx_by_name
(
node
.
name
,
self
.
nodes_list
)
node_idx
=
_find_idx_by_name
(
node
.
name
,
self
.
nodes_list
)
node_dict
=
self
.
idx_trace_list
[
node_idx
]
node_dict
=
self
.
idx_trace_list
[
node_idx
]
return
node_dict
[
'idx'
],
node_dict
[
'compute'
]
return
node_dict
[
'idx'
],
node_dict
[
'compute'
]
def
find_idx_trace_from_node
(
self
,
node
):
def
find_idx_trace_from_node
(
self
,
node
):
"""
Find node idx trace by the node.
Args:
node (node)
Returns:
idx (list): idx of the node
"""
node_idx
=
_find_idx_by_name
(
node
.
name
,
self
.
nodes_list
)
node_idx
=
_find_idx_by_name
(
node
.
name
,
self
.
nodes_list
)
return
self
.
idx_trace_list
[
node_idx
][
'idx'
]
return
self
.
idx_trace_list
[
node_idx
][
'idx'
]
def
find_compute_trace_from_node
(
self
,
node
):
def
find_compute_trace_from_node
(
self
,
node
):
"""
Find node compute trace by the node.
Args:
node (node)
Returns:
compute (list): computed idx of the node.
"""
node_idx
=
_find_idx_by_name
(
node
.
name
,
self
.
nodes_list
)
node_idx
=
_find_idx_by_name
(
node
.
name
,
self
.
nodes_list
)
return
self
.
idx_trace_list
[
node_idx
][
'compute'
]
return
self
.
idx_trace_list
[
node_idx
][
'compute'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment