Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
b7b67c32
Commit
b7b67c32
authored
Dec 12, 2022
by
oahzxl
Browse files
code style
parent
31a2c5d0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
42 deletions
+28
-42
chunk_codegen.py
chunk_codegen.py
+28
-42
No files found.
chunk_codegen.py
View file @
b7b67c32
...
...
@@ -232,7 +232,7 @@ class FlowTracer(object):
inputs_dim
=
[]
remove_inputs
=
[]
for
input_node
in
chunk_info
[
'
inputs
'
]:
for
input_node
in
chunk_info
[
"
inputs
"
]:
input_dict
=
{}
for
user
in
input_node
.
users
.
keys
():
if
_is_non_compute_node
(
user
):
...
...
@@ -252,15 +252,17 @@ class FlowTracer(object):
remove_inputs
.
append
(
input_node
)
else
:
inputs_dim
.
append
(
input_dict
)
chunk_info
[
'
inputs_dim
'
]
=
inputs_dim
chunk_info
[
"
inputs_dim
"
]
=
inputs_dim
for
i
in
remove_inputs
:
if
i
in
chunk_info
[
'
inputs
'
]:
chunk_info
[
'
inputs
'
].
remove
(
i
)
if
i
in
chunk_info
[
"
inputs
"
]:
chunk_info
[
"
inputs
"
].
remove
(
i
)
# we need to log input nodes to avoid deleteing them in the loop
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
])
non_chunk_inputs
=
_find_chunk_all_input_nodes
(
self
.
node_list
[
start_idx
:
end_idx
+
1
]
)
for
i
in
non_chunk_inputs
:
if
i
not
in
chunk_info
[
'
inputs
'
]:
if
i
not
in
chunk_info
[
"
inputs
"
]:
chunk_info
[
"inputs_non_chunk"
].
append
(
i
)
return
flow_flag
,
chunk_info
...
...
@@ -1371,44 +1373,32 @@ def _get_first_non_single_dim(shape):
def
_gen_loop_start
(
chunk_input
,
chunk_output
,
chunk_ouput_dim
,
chunk_size
=
2
):
input_node
=
chunk_input
[
0
]
out_shape
=
_get_node_shape
(
chunk_output
)
out_str
=
str
(
list
(
out_shape
))
context
=
(
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d
\n
for chunk_idx in range"
%
(
out_str
,
input_node
.
name
,
input_node
.
name
,
chunk_size
)
)
context
+=
"(0, %d, chunk_size):
\n
"
%
(
out_shape
[
chunk_ouput_dim
])
# node = chunk_input[0]
# node_shape = node.meta["tensor_meta"].shape
# free_shape = [
# node_shape[i] if i in chunk_dim else 1 for i in range(len(node_shape))
# ]
# chunk_dim = _get_first_non_single_dim(free_shape)
# chunk_slice = _gen_chunk_slice_dim(chunk_dim, "gen_chunk_idx", node_shape)
# out_shape = str(list(chunk_output.meta["tensor_meta"].shape))
# context = (
# "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor gen_chunk_idx in range"
# % (out_shape, node.name, node.name, chunk_size)
# )
# context += "(0, %s.shape[%d], chunk_size):\n" % (node.name, chunk_dim)
# context += " chunk_tensor = %s%s\n" % (node.name, chunk_slice)
return
context
def
_gen_loop_end
(
chunk_inputs
,
chunk_non_compute_inputs
,
chunk_outputs
,
chunk_outputs_dim
,
node_list
):
def
_gen_loop_end
(
chunk_inputs
,
chunk_non_compute_inputs
,
chunk_outputs
,
chunk_outputs_dim
,
node_list
):
chunk_outputs_name
=
chunk_outputs
.
name
chunk_outputs_idx
=
_find_idx_by_name
(
chunk_outputs_name
,
node_list
)
chunk_output_shape
=
chunk_outputs
.
meta
[
"tensor_meta"
].
shape
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_outputs_dim
,
"chunk_idx"
,
chunk_output_shape
)
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_outputs_dim
,
"chunk_idx"
,
chunk_output_shape
)
context
=
" chunk_result%s = %s
\n
"
%
(
chunk_slice
,
chunk_outputs_name
)
context
+=
(
chunk_outputs_name
+
" = chunk_result; chunk_result = None; chunk_size = None"
)
context
+=
(
chunk_outputs_name
+
" = chunk_result; chunk_result = None; chunk_size = None"
)
# determine if its the last use for chunk input
for
chunk_input
in
(
chunk_inputs
+
chunk_non_compute_inputs
)
:
for
chunk_input
in
chunk_inputs
+
chunk_non_compute_inputs
:
if
all
(
[
_find_idx_by_name
(
user
.
name
,
node_list
)
<=
chunk_outputs_idx
...
...
@@ -1456,10 +1446,7 @@ def _find_chunk_all_input_nodes(nodes: List[Node]):
input_nodes
=
[]
for
node
in
nodes
:
for
input_node
in
node
.
_input_nodes
.
keys
():
if
(
input_node
not
in
nodes
and
input_node
not
in
input_nodes
):
if
input_node
not
in
nodes
and
input_node
not
in
input_nodes
:
input_nodes
.
append
(
input_node
)
return
input_nodes
...
...
@@ -1549,16 +1536,12 @@ def emit_code_with_chunk(
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_search
]
chunk_inputs_non_chunk
=
[
i
[
"inputs_non_chunk"
]
for
i
in
chunk_search
]
chunk_inputs_dim
=
[
i
[
"inputs_dim"
]
for
i
in
chunk_search
]
chunk_inputs_
idx
=
[
[
_find_idx_by_name
(
j
.
name
,
node_list
)
for
j
in
i
]
for
i
in
chunk_inputs
chunk_inputs_
names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
j
.
name
for
i
in
chunk_inputs
_non_chunk
for
j
in
i
]
chunk_inputs_names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
]
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_search
]
chunk_outputs_dim
=
[
i
[
"outputs_dim"
]
for
i
in
chunk_search
]
chunk_outputs_idx
=
[
_find_idx_by_name
(
i
.
name
,
node_list
)
for
i
in
chunk_outputs
]
node_idx
=
0
region_idx
=
0
...
...
@@ -1586,7 +1569,9 @@ def emit_code_with_chunk(
for
input_node_idx
,
input_node
in
enumerate
(
chunk_inputs
[
region_idx
]):
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
if
idx
==
node_idx
:
chunk_slice
=
_gen_chunk_slice_dim
(
dim
,
"chunk_idx"
,
_get_node_shape
(
input_node
))
chunk_slice
=
_gen_chunk_slice_dim
(
dim
,
"chunk_idx"
,
_get_node_shape
(
input_node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
...
...
@@ -1604,7 +1589,8 @@ def emit_code_with_chunk(
chunk_inputs
[
region_idx
],
chunk_inputs_non_chunk
[
region_idx
],
chunk_outputs
[
region_idx
],
chunk_outputs_dim
[
region_idx
],
node_list
chunk_outputs_dim
[
region_idx
],
node_list
,
)
)
within_chunk_region
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment