Unverified Commit 6c331a5a authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[autoparallel] refactored the autoparallel module for organization (#1706)

* [autoparallel] refactored the autoparallel module for organization

* polish code
parent 91cd34e6
import torch import torch
from colossalai.auto_parallel.solver.node_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable,
recover_sharding_spec_for_broadcast_shape)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
def test_is_broadcastable(): def test_is_broadcastable():
......
import torch.nn as nn
import torch import torch
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
from colossalai.fx import ColoGraphModule, ColoTracer
class LinearModel(nn.Module): class LinearModel(nn.Module):
......
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.batch_norm_handler import BatchNormModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import \
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector BatchNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
def test_bn_module_handler(): def test_bn_module_handler():
......
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import \
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
......
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvModuleHandler, ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import (ConvFunctionHandler, ConvModuleHandler)
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
def test_conv_module_handler(): def test_conv_module_handler():
......
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.getitem_handler import GetItemHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import \
GetItemHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
class GetItemModel(nn.Module): class GetItemModel(nn.Module):
......
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.layer_norm_handler import LayerNormModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import \
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
def test_ln_module_handler(): def test_ln_module_handler():
......
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.dot_handler import LinearModuleHandler, LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import (LinearFunctionHandler, LinearModuleHandler)
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
......
from colossalai.fx.tracer.meta_patch.patched_module import linear import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import NormPoolingHandler from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
import pytest from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector OuputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
class OutputModel(nn.Module): class OutputModel(nn.Module):
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector PlacehodlerHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
class PlaceholderModel(nn.Module): class PlaceholderModel(nn.Module):
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
from colossalai.auto_parallel.solver.node_handler.reshape_handler import ReshapeHandler ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import \
ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
class ReshapeModel(nn.Module): class ReshapeModel(nn.Module):
......
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.unary_elementwise_handler import UnaryElementwiseHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler ConvFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
class ReLuModel(nn.Module): class ReLuModel(nn.Module):
......
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch import torch
import torch.nn as nn import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.where_handler import WhereHandler from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
class ConvModel(nn.Module): class ConvModel(nn.Module):
......
from functools import partial from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.fx import GraphModule
import torch.nn as nn import torch.nn as nn
import pytest from torch.fx import GraphModule
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
StrategiesConstructor)
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (shape_consistency_pass,
from colossalai.auto_parallel.solver.solver import Solver_V2 solution_annotatation_pass)
from colossalai.auto_parallel.solver.options import SolverOptions from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
class ConvModel(nn.Module): class ConvModel(nn.Module):
...@@ -61,7 +59,7 @@ def check_apply(rank, world_size, port): ...@@ -61,7 +59,7 @@ def check_apply(rank, world_size, port):
cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph() cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm) graph_analyser = GraphAnalyser(gm)
solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser) solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args() ret = solver.call_solver_serialized_args()
solution = list(ret[0]) solution = list(ret[0])
device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh() device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
......
import torch import torch
from torch.fx import GraphModule from torch.fx import GraphModule
import torch.nn as nn from torchvision.models import resnet50
import pytest
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
StrategiesConstructor)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.solver.solver import Solver
from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment