Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
00778abc
Commit
00778abc
authored
Mar 27, 2023
by
CsRic
Committed by
binmakeswell
Mar 29, 2023
Browse files
[NFC] polish colossalai/fx/passes/split_module.py code style (#3263)
Co-authored-by:
csric
<
richcsr256@gmail.com
>
parent
488f3704
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
10 deletions
+9
-10
colossalai/fx/passes/split_module.py
colossalai/fx/passes/split_module.py
+9
-10
No files found.
colossalai/fx/passes/split_module.py
View file @
00778abc
import
inspect
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
import
torch
from
torch.fx.graph_module
import
GraphModule
from
typing
import
Callable
,
List
,
Dict
,
Any
,
Optional
from
torch.fx._compatibility
import
compatibility
from
packaging
import
version
from
packaging
import
version
import
inspect
from
torch.fx._compatibility
import
compatibility
from
torch.fx.graph_module
import
GraphModule
@
compatibility
(
is_backward_compatible
=
True
)
@
compatibility
(
is_backward_compatible
=
True
)
...
@@ -38,7 +39,7 @@ def split_module(
...
@@ -38,7 +39,7 @@ def split_module(
m
:
GraphModule
,
m
:
GraphModule
,
root_m
:
torch
.
nn
.
Module
,
root_m
:
torch
.
nn
.
Module
,
split_callback
:
Callable
[[
torch
.
fx
.
node
.
Node
],
int
],
split_callback
:
Callable
[[
torch
.
fx
.
node
.
Node
],
int
],
merge_output
=
False
,
merge_output
=
False
,
):
):
"""
"""
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
...
@@ -132,10 +133,8 @@ def split_module(
...
@@ -132,10 +133,8 @@ def split_module(
use_partition
.
inputs
.
setdefault
(
def_node
.
name
)
use_partition
.
inputs
.
setdefault
(
def_node
.
name
)
if
def_partition_name
is
not
None
:
if
def_partition_name
is
not
None
:
use_partition
.
partitions_dependent_on
.
setdefault
(
def_partition_name
)
use_partition
.
partitions_dependent_on
.
setdefault
(
def_partition_name
)
def
record_output
(
def
record_output
(
def_node
:
torch
.
fx
.
node
.
Node
,
use_node
:
Optional
[
torch
.
fx
.
node
.
Node
]):
# noqa: B950
def_node
:
torch
.
fx
.
node
.
Node
,
use_node
:
Optional
[
torch
.
fx
.
node
.
Node
]
):
# noqa: B950
def_partition_name
=
getattr
(
def_node
,
"_fx_partition"
,
None
)
def_partition_name
=
getattr
(
def_node
,
"_fx_partition"
,
None
)
use_partition_name
=
getattr
(
use_node
,
"_fx_partition"
,
None
)
use_partition_name
=
getattr
(
use_node
,
"_fx_partition"
,
None
)
if
def_partition_name
!=
use_partition_name
:
if
def_partition_name
!=
use_partition_name
:
...
@@ -291,7 +290,7 @@ def split_module(
...
@@ -291,7 +290,7 @@ def split_module(
for
partition_name
in
sorted_partitions
:
for
partition_name
in
sorted_partitions
:
partition
=
partitions
[
partition_name
]
partition
=
partitions
[
partition_name
]
new_gm
=
torch
.
fx
.
graph_module
.
GraphModule
(
base_mod_attrs
,
base_mod_graph
)
new_gm
=
torch
.
fx
.
graph_module
.
GraphModule
(
base_mod_attrs
,
base_mod_graph
)
return
new_gm
return
new_gm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment