Unverified Commit 8283e95d authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[autoparallel] collated all deprecated files (#1700)

* [autoparallel] collated all deprecated files

* polish code
parent e2355d01
......@@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.where_handler_v2 import WhereHandler
from colossalai.auto_parallel.solver.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
......
......@@ -7,10 +7,10 @@ from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor_V2
from colossalai.auto_parallel.solver.cost_graph import CostGraph_V2
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.solver.solver import Solver_V2
from colossalai.auto_parallel.solver.solver import Solver
from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
......@@ -60,12 +60,12 @@ def test_cost_graph():
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor_V2(graph, device_mesh, solver_options)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph_V2(strategies_constructor.leaf_strategies)
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
print(ret[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment