Commit 93b788b9 authored by binmakeswell's avatar binmakeswell
Browse files

Merge branch 'main' into fix/format

parents 2fd528b9 1dc003c1
......@@ -198,54 +198,54 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
if model_cls.__name__ == 'LinearSplitModel':
if split_dim == 0:
assert '[R, R, R, S1]_0' in strategy_name_list
assert '[R, S0, R, S1]_1' in strategy_name_list
assert '[R, R, S0, S1]_2' in strategy_name_list
assert '[R, R, R, S0]_3' in strategy_name_list
assert '[R, S1, R, S0]_4' in strategy_name_list
assert '[R, R, S1, S0]_5' in strategy_name_list
assert '[R, R, R, R]_6' in strategy_name_list
assert '[R, S0, R, R]_7' in strategy_name_list
assert '[R, R, S0, R]_8' in strategy_name_list
assert '[R, R, R, R]_9' in strategy_name_list
assert '[R, S1, R, R]_10' in strategy_name_list
assert '[R, R, S1, R]_11' in strategy_name_list
assert '[R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1]_17' in strategy_name_list
assert '[R, R, R, R]_18' in strategy_name_list
assert '[R, S01, R, R]_19' in strategy_name_list
assert '[R, R, S01, R]_20' in strategy_name_list
assert '[R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01]_22' in strategy_name_list
assert '[R, R, R, S1]_11' in strategy_name_list
assert '[R, S0, R, S1]_12' in strategy_name_list
assert '[R, R, S0, S1]_13' in strategy_name_list
assert '[R, R, R, S0]_14' in strategy_name_list
assert '[R, S1, R, S0]_15' in strategy_name_list
assert '[R, R, S1, S0]_16' in strategy_name_list
assert '[R, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R]_18' in strategy_name_list
assert '[R, R, S0, R]_19' in strategy_name_list
assert '[R, R, R, R]_20' in strategy_name_list
assert '[R, S1, R, R]_21' in strategy_name_list
assert '[R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, S1]_5' in strategy_name_list
assert '[R, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R]_1' in strategy_name_list
assert '[R, R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01]_4' in strategy_name_list
if split_dim == 1:
assert '[S0, R, R, S1]_0' in strategy_name_list
assert '[R, R, R, S1]_1' in strategy_name_list
assert '[R, R, S0, S1]_2' in strategy_name_list
assert '[S1, R, R, S0]_3' in strategy_name_list
assert '[R, R, R, S0]_4' in strategy_name_list
assert '[R, R, S1, S0]_5' in strategy_name_list
assert '[S0, R, R, R]_6' in strategy_name_list
assert '[R, R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R]_8' in strategy_name_list
assert '[S1, R, R, R]_9' in strategy_name_list
assert '[R, R, R, R]_10' in strategy_name_list
assert '[R, R, S1, R]_11' in strategy_name_list
assert '[S0, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1]_17' in strategy_name_list
assert '[S01, R, R, R]_18' in strategy_name_list
assert '[R, R, R, R]_19' in strategy_name_list
assert '[R, R, S01, R]_20' in strategy_name_list
assert '[R, R, S0, S1]_13' in strategy_name_list
assert '[S1, R, R, S0]_14' in strategy_name_list
assert '[R, R, R, S0]_15' in strategy_name_list
assert '[R, R, S1, S0]_16' in strategy_name_list
assert '[S0, R, R, R]_17' in strategy_name_list
assert '[R, R, R, R]_18' in strategy_name_list
assert '[R, R, S0, R]_19' in strategy_name_list
assert '[S1, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01]_22' in strategy_name_list
assert '[R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, S1]_5' in strategy_name_list
assert '[S01, R, R, R]_0' in strategy_name_list
assert '[R, R, R, R]_1' in strategy_name_list
assert '[R, R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01]_4' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL')
......
......@@ -196,54 +196,57 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
if model_cls.__name__ == 'LinearViewModel':
if tgt_shape == (32, 4, 64, 16, 4):
assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_0' in strategy_name_list
assert '[R, S0, R, S1] -> FULLY REPLICATED_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_3' in strategy_name_list
assert '[R, S1, R, S0] -> FULLY REPLICATED_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R, R]_9' in strategy_name_list
assert '[R, S1, R, R] -> FULLY REPLICATED_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1, R]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0, R]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0, R]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1, R]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> FULLY REPLICATED_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01, R]_22' in strategy_name_list
for strategy in strategy_name_list:
print(strategy)
# print(strategy_name_list)
assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_11' in strategy_name_list
assert '[R, S0, R, S1] -> FULLY REPLICATED_12' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_14' in strategy_name_list
assert '[R, S1, R, S0] -> FULLY REPLICATED_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_16' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> FULLY REPLICATED_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R, R]_20' in strategy_name_list
assert '[R, S1, R, R] -> FULLY REPLICATED_21' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R, R]_22' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1, R]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0, R]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0, R]_6' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1, R]_5' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> FULLY REPLICATED_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01, R]_4' in strategy_name_list
if tgt_shape == (8, 4, 4, 64, 16, 4):
assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_9' in strategy_name_list
assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_22' in strategy_name_list
assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_20' in strategy_name_list
assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_21' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_22' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_6' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_5' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_4' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL')
......
......@@ -6,7 +6,8 @@ from torch.fx import GraphModule
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
......
import torch
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
......
......@@ -3,13 +3,8 @@ from torch.fx import GraphModule
from torchvision.models import resnet50
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment