Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
212b5b1b
Commit
212b5b1b
authored
Jan 09, 2023
by
oahzxl
Browse files
add comments
parent
19cc64b1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
15 deletions
+22
-15
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+21
-14
tests/test_autochunk/test_autochunk_codegen.py
tests/test_autochunk/test_autochunk_codegen.py
+1
-1
No files found.
colossalai/autochunk/autochunk_codegen.py
View file @
212b5b1b
from
typing
import
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Tuple
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Tuple
import
torch
from
torch.fx.graph
import
(
...
...
@@ -128,37 +128,42 @@ def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, bod
def
emit_code_with_chunk
(
body
,
nodes
,
body
:
List
[
str
]
,
nodes
:
Iterable
[
Node
]
,
emit_node_func
,
delete_unused_value_func
,
search_chunk
:
SearchChunk
,
chunk_infos
,
chunk_infos
:
List
,
):
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
this function to emit the activation checkpoint codes.
"""
Emit code with chunk according to chunk_infos.
It will generate a for loop in chunk regions, and replace inputs
and outputs of regions with chunked variables.
Args:
body: forward code
ckpt_func: checkpoint functions code
nodes: graph.nodes
emit_node_func: function to emit node
delete_unused_value_func: function to remove the unused value
search_chunk: the class to search all chunks
chunk_infos: store all information about all chunks.
"""
node_list
=
list
(
nodes
)
chunk
_regions
=
[
i
[
"region"
]
for
i
in
chunk_infos
]
chunk_starts
=
[
i
[
0
]
for
i
in
chunk_
region
s
]
chunk_ends
=
[
i
[
1
]
for
i
in
chunk_
region
s
]
#
chunk
region
chunk_starts
=
[
i
[
"region"
][
0
]
for
i
in
chunk_
info
s
]
chunk_ends
=
[
i
[
"region"
][
1
]
for
i
in
chunk_
info
s
]
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_infos
]
chunk_inputs_non_chunk
=
[
i
[
"inputs_non_chunk"
]
for
i
in
chunk_infos
]
chunk_inputs_dim
=
[
i
[
"inputs_dim"
]
for
i
in
chunk_infos
]
# chunk inputs
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_infos
]
# input with chunk
chunk_inputs_non_chunk
=
[
i
[
"inputs_non_chunk"
]
for
i
in
chunk_infos
]
# input without chunk
chunk_inputs_dim
=
[
i
[
"inputs_dim"
]
for
i
in
chunk_infos
]
# input chunk dim
chunk_inputs_names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
]
# chunk outputs
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_infos
]
chunk_outputs_dim
=
[
i
[
"outputs_dim"
]
for
i
in
chunk_infos
]
...
...
@@ -170,6 +175,7 @@ def emit_code_with_chunk(
while
node_idx
<
len
(
node_list
):
node
=
node_list
[
node_idx
]
# if is chunk start, generate for loop start
if
node_idx
in
chunk_starts
:
within_chunk_region
=
True
region_idx
=
chunk_starts
.
index
(
node_idx
)
...
...
@@ -203,6 +209,7 @@ def emit_code_with_chunk(
if
node_idx
not
in
chunk_inputs
:
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
# generate chunk region end
if
node_idx
in
chunk_ends
:
body
.
append
(
_gen_loop_end
(
...
...
tests/test_autochunk/test_autochunk_codegen.py
View file @
212b5b1b
...
...
@@ -115,4 +115,4 @@ def test_autochunk_codegen(msa_len, pair_len, max_memory):
if
__name__
==
"__main__"
:
_test_autochunk_codegen
(
0
,
32
,
64
,
None
)
_test_autochunk_codegen
(
0
,
32
,
64
,
25
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment