Unverified Commit cdb7d5e7 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[hotfix] autoparallel unit test (#1752)

parent a4ce180e
from .operator_handler import OperatorHandler
from .dot_handler import DotHandler
from .conv_handler import ConvHandler
from .batch_norm_handler import BatchNormHandler from .batch_norm_handler import BatchNormHandler
from .reshape_handler import ReshapeHandler
from .bcast_op_handler import BcastOpHandler from .bcast_op_handler import BcastOpHandler
from .conv_handler import ConvHandler
from .dot_handler import DotHandler
from .embedding_handler import EmbeddingHandler from .embedding_handler import EmbeddingHandler
from .layer_norm_handler import LayerNormHandler
from .operator_handler import OperatorHandler
from .reshape_handler import ReshapeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler from .unary_elementwise_handler import UnaryElementwiseHandler
from .where_handler import WhereHandler from .where_handler import WhereHandler
__all__ = [ __all__ = [
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler', 'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler' 'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler', 'LayerNormHandler'
] ]
from copy import deepcopy
import pytest
import torch import torch
from torch.fx import GraphModule
import torch.nn as nn import torch.nn as nn
import pytest from torch.fx import GraphModule
from colossalai.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.deprecated import Solver
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
from copy import deepcopy
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
...@@ -60,7 +61,7 @@ def test_solver(): ...@@ -60,7 +61,7 @@ def test_solver():
gm = GraphModule(model, graph, model.__class__.__name__) gm = GraphModule(model, graph, model.__class__.__name__)
solver_options = SolverOptions(fast=True) solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost() strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies) cost_graph = CostGraph(strategies_constructor.leaf_strategies)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment