Commit 00778abc authored by CsRic's avatar CsRic Committed by binmakeswell
Browse files

[NFC] polish colossalai/fx/passes/split_module.py code style (#3263)


Co-authored-by: default avatarcsric <richcsr256@gmail.com>
parent 488f3704
import inspect
from typing import Any, Callable, Dict, List, Optional
import torch import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version from packaging import version
import inspect from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
...@@ -38,7 +39,7 @@ def split_module( ...@@ -38,7 +39,7 @@ def split_module(
m: GraphModule, m: GraphModule,
root_m: torch.nn.Module, root_m: torch.nn.Module,
split_callback: Callable[[torch.fx.node.Node], int], split_callback: Callable[[torch.fx.node.Node], int],
merge_output = False, merge_output=False,
): ):
""" """
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
...@@ -132,10 +133,8 @@ def split_module( ...@@ -132,10 +133,8 @@ def split_module(
use_partition.inputs.setdefault(def_node.name) use_partition.inputs.setdefault(def_node.name)
if def_partition_name is not None: if def_partition_name is not None:
use_partition.partitions_dependent_on.setdefault(def_partition_name) use_partition.partitions_dependent_on.setdefault(def_partition_name)
def record_output( def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950
def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
): # noqa: B950
def_partition_name = getattr(def_node, "_fx_partition", None) def_partition_name = getattr(def_node, "_fx_partition", None)
use_partition_name = getattr(use_node, "_fx_partition", None) use_partition_name = getattr(use_node, "_fx_partition", None)
if def_partition_name != use_partition_name: if def_partition_name != use_partition_name:
...@@ -291,7 +290,7 @@ def split_module( ...@@ -291,7 +290,7 @@ def split_module(
for partition_name in sorted_partitions: for partition_name in sorted_partitions:
partition = partitions[partition_name] partition = partitions[partition_name]
new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
return new_gm return new_gm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment