Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
5d0fdb9c
Unverified
Commit
5d0fdb9c
authored
Sep 27, 2022
by
Boyuan Yao
Committed by
GitHub
Sep 27, 2022
Browse files
[fx] fix offload codegen test (#1648)
* [fx] fix offload codegen test * [fx] modify typing
parent
45b39a69
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
12 deletions
+12
-12
colossalai/fx/codegen/activation_checkpoint_codegen.py
colossalai/fx/codegen/activation_checkpoint_codegen.py
+4
-4
tests/test_fx/test_codegen/test_offload_codegen.py
tests/test_fx/test_codegen/test_offload_codegen.py
+8
-8
No files found.
colossalai/fx/codegen/activation_checkpoint_codegen.py
View file @
5d0fdb9c
import
colossalai
import
colossalai
import
torch
import
torch
from
typing
import
List
,
Callable
,
Any
,
Tuple
,
Dict
from
typing
import
List
,
Callable
,
Any
,
Tuple
,
Dict
,
Iterable
try
:
try
:
from
torch.fx.node
import
Node
,
Argument
,
map_arg
,
_type_repr
,
_get_qualified_name
from
torch.fx.node
import
Node
,
Argument
,
map_arg
,
_type_repr
,
_get_qualified_name
...
@@ -157,7 +157,7 @@ def _find_offload_regions(nodes: List[Node]):
...
@@ -157,7 +157,7 @@ def _find_offload_regions(nodes: List[Node]):
current_region
=
None
current_region
=
None
for
idx
,
node
in
enumerate
(
nodes
):
for
idx
,
node
in
enumerate
(
nodes
):
if
hasattr
(
node
,
'activation_offload'
)
and
isinstance
(
getattr
(
node
,
'activation_offload'
,
None
),
list
):
if
hasattr
(
node
,
'activation_offload'
)
and
isinstance
(
getattr
(
node
,
'activation_offload'
,
None
),
Iterable
):
act_offload_label
=
node
.
activation_offload
act_offload_label
=
node
.
activation_offload
if
current_region
==
None
:
if
current_region
==
None
:
...
@@ -796,7 +796,7 @@ if CODEGEN_AVAILABLE:
...
@@ -796,7 +796,7 @@ if CODEGEN_AVAILABLE:
# if any node has a list of labels for activation_checkpoint, we
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
# will use nested type of activation checkpoint codegen
if
any
(
isinstance
(
getattr
(
node
,
"activation_checkpoint"
,
None
),
list
)
for
node
in
nodes
):
if
any
(
isinstance
(
getattr
(
node
,
"activation_checkpoint"
,
None
),
Iterable
)
for
node
in
nodes
):
emit_code_with_nested_activation_checkpoint
(
body
,
ckpt_func
,
nodes
,
emit_node
,
delete_unused_values
)
emit_code_with_nested_activation_checkpoint
(
body
,
ckpt_func
,
nodes
,
emit_node
,
delete_unused_values
)
else
:
else
:
emit_code_with_activation_checkpoint
(
body
,
ckpt_func
,
nodes
,
emit_node
,
delete_unused_values
)
emit_code_with_activation_checkpoint
(
body
,
ckpt_func
,
nodes
,
emit_node
,
delete_unused_values
)
...
@@ -999,7 +999,7 @@ else:
...
@@ -999,7 +999,7 @@ else:
# if any node has a list of labels for activation_checkpoint, we
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
# will use nested type of activation checkpoint codegen
if
any
(
isinstance
(
getattr
(
node
,
"activation_checkpoint"
,
None
),
list
)
for
node
in
self
.
nodes
):
if
any
(
isinstance
(
getattr
(
node
,
"activation_checkpoint"
,
None
),
Iterable
)
for
node
in
self
.
nodes
):
emit_code_with_nested_activation_checkpoint
(
body
,
ckpt_func
,
self
.
nodes
,
emit_node
,
delete_unused_values
)
emit_code_with_nested_activation_checkpoint
(
body
,
ckpt_func
,
self
.
nodes
,
emit_node
,
delete_unused_values
)
else
:
else
:
emit_code_with_activation_checkpoint
(
body
,
ckpt_func
,
self
.
nodes
,
emit_node
,
delete_unused_values
)
emit_code_with_activation_checkpoint
(
body
,
ckpt_func
,
self
.
nodes
,
emit_node
,
delete_unused_values
)
...
...
tests/test_fx/test_codegen/test_offload_codegen.py
View file @
5d0fdb9c
...
@@ -83,13 +83,13 @@ def _run_offload_codegen(rank):
...
@@ -83,13 +83,13 @@ def _run_offload_codegen(rank):
# of input offload
# of input offload
for
node
in
graph
.
nodes
:
for
node
in
graph
.
nodes
:
if
node
.
name
==
"linear0"
:
if
node
.
name
==
"linear0"
:
setattr
(
node
,
"activation_offload"
,
(
0
,
True
,
False
)
)
setattr
(
node
,
"activation_offload"
,
[
0
,
True
,
False
]
)
if
node
.
name
==
"linear1"
:
if
node
.
name
==
"linear1"
:
setattr
(
node
,
"activation_offload"
,
(
0
,
True
,
False
)
)
setattr
(
node
,
"activation_offload"
,
[
0
,
True
,
False
]
)
if
node
.
name
==
"linear2"
:
if
node
.
name
==
"linear2"
:
setattr
(
node
,
"activation_offload"
,
(
1
,
True
,
True
)
)
setattr
(
node
,
"activation_offload"
,
[
1
,
True
,
True
]
)
if
node
.
name
==
"linear4"
:
if
node
.
name
==
"linear4"
:
setattr
(
node
,
"activation_offload"
,
(
2
,
False
,
True
)
)
setattr
(
node
,
"activation_offload"
,
[
2
,
False
,
True
]
)
if
node
.
name
==
"linear5"
:
if
node
.
name
==
"linear5"
:
setattr
(
node
,
"activation_checkpoint"
,
[
0
])
setattr
(
node
,
"activation_checkpoint"
,
[
0
])
setattr
(
node
,
"activation_offload"
,
True
)
setattr
(
node
,
"activation_offload"
,
True
)
...
@@ -138,13 +138,13 @@ def _run_offload_codegen_torch11(rank):
...
@@ -138,13 +138,13 @@ def _run_offload_codegen_torch11(rank):
# of input offload
# of input offload
for
node
in
graph
.
nodes
:
for
node
in
graph
.
nodes
:
if
node
.
name
==
"linear0"
:
if
node
.
name
==
"linear0"
:
setattr
(
node
,
"activation_offload"
,
(
0
,
True
,
False
)
)
setattr
(
node
,
"activation_offload"
,
[
0
,
True
,
False
]
)
if
node
.
name
==
"linear1"
:
if
node
.
name
==
"linear1"
:
setattr
(
node
,
"activation_offload"
,
(
0
,
True
,
False
)
)
setattr
(
node
,
"activation_offload"
,
[
0
,
True
,
False
]
)
if
node
.
name
==
"linear2"
:
if
node
.
name
==
"linear2"
:
setattr
(
node
,
"activation_offload"
,
(
1
,
True
,
True
)
)
setattr
(
node
,
"activation_offload"
,
[
1
,
True
,
True
]
)
if
node
.
name
==
"linear4"
:
if
node
.
name
==
"linear4"
:
setattr
(
node
,
"activation_offload"
,
(
2
,
False
,
True
)
)
setattr
(
node
,
"activation_offload"
,
[
2
,
False
,
True
]
)
if
node
.
name
==
"linear5"
:
if
node
.
name
==
"linear5"
:
setattr
(
node
,
"activation_checkpoint"
,
[
0
])
setattr
(
node
,
"activation_checkpoint"
,
[
0
])
setattr
(
node
,
"activation_offload"
,
True
)
setattr
(
node
,
"activation_offload"
,
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment