Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8a989a0d
Commit
8a989a0d
authored
Jan 06, 2023
by
oahzxl
Browse files
code style
parent
c3a2bf48
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
29 deletions
+40
-29
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+40
-29
No files found.
colossalai/autochunk/autochunk_codegen.py
View file @
8a989a0d
...
...
@@ -98,6 +98,39 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
return
context
def
_replace_ones_like
(
search_chunk
,
chunk_infos
,
region_idx
,
node_idx
,
node
,
body
):
if
"ones_like"
in
node
.
name
:
meta_node
=
search_chunk
.
trace_index
.
node_list
[
node_idx
]
chunk_dim
=
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
meta_node
][
"chunk_dim"
]
if
get_node_shape
(
meta_node
)[
chunk_dim
]
!=
1
:
source_node
=
meta_node
.
args
[
0
].
args
[
0
]
if
(
source_node
not
in
chunk_infos
[
region_idx
][
"node_chunk_dim"
]
or
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
source_node
][
"chunk_dim"
]
is
None
):
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_dim
,
"chunk_idx"
,
get_node_shape
(
node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
node
.
args
[
0
].
name
,
node
.
args
[
0
].
name
+
chunk_slice
)
return
body
def
_replace_input_var
(
chunk_inputs
,
region_idx
,
chunk_inputs_dim
,
node_idx
,
body
):
for
input_node_idx
,
input_node
in
enumerate
(
chunk_inputs
[
region_idx
]):
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
if
idx
==
node_idx
:
chunk_slice
=
_gen_chunk_slice_dim
(
dim
[
0
],
"chunk_idx"
,
get_node_shape
(
input_node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
return
body
def
emit_code_with_chunk
(
body
,
nodes
,
...
...
@@ -156,36 +189,14 @@ def emit_code_with_chunk(
if
within_chunk_region
:
emit_node_func
(
node
,
body
)
# replace input var with chunk var
for
input_node_idx
,
input_node
in
enumerate
(
chunk_inputs
[
region_idx
]):
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
if
idx
==
node_idx
:
chunk_slice
=
_gen_chunk_slice_dim
(
dim
[
0
],
"chunk_idx"
,
get_node_shape
(
input_node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
body
=
_replace_input_var
(
chunk_inputs
,
region_idx
,
chunk_inputs_dim
,
node_idx
,
body
)
# ones like
if
"ones_like"
in
node
.
name
:
meta_node
=
search_chunk
.
trace_index
.
node_list
[
node_idx
]
chunk_dim
=
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
meta_node
][
"chunk_dim"
]
if
get_node_shape
(
meta_node
)[
chunk_dim
]
!=
1
:
source_node
=
meta_node
.
args
[
0
].
args
[
0
]
if
(
source_node
not
in
chunk_infos
[
region_idx
][
"node_chunk_dim"
]
or
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
source_node
][
"chunk_dim"
]
is
None
):
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_dim
,
"chunk_idx"
,
get_node_shape
(
node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
node
.
args
[
0
].
name
,
node
.
args
[
0
].
name
+
chunk_slice
)
body
=
_replace_ones_like
(
search_chunk
,
chunk_infos
,
region_idx
,
node_idx
,
node
,
body
)
# reassgin reshape size
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
],
node
.
name
,
chunk_infos
[
region_idx
][
"reshape_size"
]
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment