Unverified Commit 80eba05b authored by Frank Lee, committed by GitHub

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
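
Note on the pattern: every distributed test below receives the same mechanical change — the repeated `functools.partial` + `free_port()` + `mp.spawn` boilerplate is replaced by one call to the new `colossalai.testing.spawn` helper. The following is only a minimal sketch of such a helper for orientation, not the actual `colossalai.testing` implementation; `free_port` here is a local stand-in rather than the `colossalai.utils` function.

```python
import socket
from functools import partial

import torch.multiprocessing as mp


def free_port() -> int:
    # Hypothetical stand-in for colossalai.utils.free_port: bind to port 0
    # so the OS hands back an unused TCP port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('localhost', 0))
        return s.getsockname()[1]


def spawn(func, nprocs=1, **kwargs):
    # Sketch of the helper this commit adopts. torch.multiprocessing passes
    # the rank as the first positional argument, so `func` must accept
    # (rank, world_size, port, **kwargs) -- which is why the check_* helpers
    # in this diff move world_size and port to the front of their signatures.
    wrapped = partial(func, world_size=nprocs, port=free_port(), **kwargs)
    mp.spawn(wrapped, nprocs=nprocs)
```

With a helper of this shape, each test body collapses to a one-liner such as `spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim)`, as the hunks below show.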
@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHan
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class ReshapeModel(nn.Module):
@@ -23,6 +23,7 @@ class ReshapeModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_reshape_handler():
    model = ReshapeModel()
    tracer = ColoTracer(bias_addition_split=True)
......
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -16,9 +13,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
NUM_EMBEDDINGS = 16
@@ -272,18 +268,14 @@ def check_embedding_function_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_module_handler():
-    world_size = 4
-    run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_embedding_module_handler, 4)
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_function_handler():
-    world_size = 4
-    run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_embedding_function_handler, 4)
if __name__ == '__main__':
......
@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
+from colossalai.testing import clear_cache_before_run
class GetattrModel(nn.Module):
@@ -22,6 +23,7 @@ class GetattrModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
+@clear_cache_before_run()
def test_getattr_handler():
    model = GetattrModel()
    tracer = ColoTracer(bias_addition_split=True)
......
@@ -2,7 +2,6 @@ from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -14,12 +13,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import Li
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -103,12 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))])
def test_getitem_from_tensor_handler(getitem_index):
-    world_size = 4
-    run_func = partial(check_getitem_from_tensor_handler,
-                       getitem_index=getitem_index,
-                       world_size=world_size,
-                       port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_getitem_from_tensor_handler, 4)
class GetItemFromTupleModel(nn.Module):
@@ -123,6 +115,7 @@ class GetItemFromTupleModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_getitem_from_tuple_handler():
    model = GetItemFromTupleModel()
    tracer = ColoTracer()
......
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -11,12 +8,10 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -104,9 +99,7 @@ def check_ln_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_ln_module_handler():
-    world_size = 4
-    run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_ln_module_handler, 4)
if __name__ == '__main__':
......
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -18,14 +15,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
-def check_linear_module_handler(rank, bias, input_shape, world_size, port):
+def check_linear_module_handler(rank, world_size, port, bias, input_shape):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
@@ -172,7 +168,7 @@ class LinearModel(nn.Module):
        return x
-def check_linear_function_handler(rank, bias, input_shape, world_size, port):
+def check_linear_function_handler(rank, world_size, port, bias, input_shape):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = LinearModel().cuda()
@@ -313,19 +309,18 @@ def check_linear_function_handler(rank, world_size, port, bias, input_shape):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(input_shape, bias=False):
-    world_size = 4
-    run_func_module = partial(check_linear_module_handler,
-                              bias=bias,
-                              input_shape=input_shape,
-                              world_size=world_size,
-                              port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
-    run_func_function = partial(check_linear_function_handler,
-                                bias=bias,
-                                input_shape=input_shape,
-                                world_size=world_size,
-                                port=free_port())
-    mp.spawn(run_func_function, nprocs=world_size)
+    spawn(
+        check_linear_module_handler,
+        4,
+        bias=bias,
+        input_shape=input_shape,
+    )
+    spawn(
+        check_linear_function_handler,
+        4,
+        bias=bias,
+        input_shape=input_shape,
+    )
if __name__ == '__main__':
......
@@ -18,7 +18,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing.utils import parameterize
+from colossalai.testing.utils import clear_cache_before_run, parameterize
class MatMulModule(nn.Module):
@@ -28,6 +28,7 @@ class MatMulModule(nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
@parameterize(
    'tensor_shapes',
    [
......
import pytest
import torch
import torch.nn as nn
@@ -8,11 +7,11 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.meta_patch.patched_module import linear
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_norm_pool_handler():
    model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
    tracer = ColoTracer(bias_addition_split=True)
......
@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import clear_cache_before_run, parameterize
class OutputModel(nn.Module):
@@ -23,7 +23,7 @@ class OutputModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('output_option', ['distributed', 'replicated'])
-@rerun_if_address_is_in_use()
+@clear_cache_before_run()
def test_output_handler(output_option):
    model = OutputModel()
    tracer = ColoTracer(bias_addition_split=True)
......
@@ -2,7 +2,6 @@ from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +14,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -55,7 +53,7 @@ class LinearReshapeModel(nn.Module):
        return permute_node
-def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
+def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    if call_function == torch.permute:
@@ -328,14 +326,13 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))])
@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel])
def test_view_handler(call_function, reshape_dims, model_cls):
-    world_size = 4
-    run_func = partial(check_view_handler,
-                       call_function=call_function,
-                       reshape_dims=reshape_dims,
-                       model_cls=model_cls,
-                       world_size=world_size,
-                       port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(
+        check_view_handler,
+        4,
+        call_function=call_function,
+        reshape_dims=reshape_dims,
+        model_cls=model_cls,
+    )
if __name__ == '__main__':
......
@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import clear_cache_before_run, parameterize
class PlaceholderModel(nn.Module):
@@ -22,7 +22,7 @@ class PlaceholderModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('placeholder_option', ['distributed', 'replicated'])
-@rerun_if_address_is_in_use()
+@clear_cache_before_run()
def test_placeholder_handler(placeholder_option):
    model = PlaceholderModel()
    tracer = ColoTracer(bias_addition_split=True)
......
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -9,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHan
from colossalai.auto_parallel.tensor_shard.options import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class LinearModel(nn.Module):
@@ -108,6 +107,7 @@ def check_shard_option(shard_option):
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_shard_option():
    # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
    for shard_option in [ShardOption.SHARD_LAST_AXIS]:
......
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -33,7 +28,7 @@ class LinearSplitModel(nn.Module):
        return softmax_node
-def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
+def check_split_handler(rank, world_size, port, softmax_dim, model_cls):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = model_cls(softmax_dim=softmax_dim).cuda()
@@ -176,13 +171,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
@parameterize('softmax_dim', [0, 1, 2, 3])
@parameterize('model_cls', [LinearSplitModel])
def test_split_handler(softmax_dim, model_cls):
-    world_size = 4
-    run_func = partial(check_split_handler,
-                       softmax_dim=softmax_dim,
-                       model_cls=model_cls,
-                       world_size=world_size,
-                       port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls)
if __name__ == '__main__':
......
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -47,7 +42,7 @@ class LinearSplitModel(nn.Module):
        return split_node
-def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port):
+def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = model_cls(split_size=split_size, split_dim=split_dim).cuda()
@@ -258,14 +253,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
@parameterize('split_dim', [0, 1, 2])
@parameterize('model_cls', [ConvSplitModel, LinearSplitModel])
def test_split_handler(split_size, split_dim, model_cls):
-    world_size = 4
-    run_func = partial(check_split_handler,
-                       split_size=split_size,
-                       split_dim=split_dim,
-                       model_cls=model_cls,
-                       world_size=world_size,
-                       port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls)
if __name__ == '__main__':
......
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -14,9 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -36,7 +31,7 @@ class LinearSumModel(nn.Module):
        return sum_node
-def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
+def check_sum_handler(rank, world_size, port, sum_dims, keepdim):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda()
@@ -228,9 +223,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
@parameterize('sum_dims', [(0, 2), 1])
@parameterize('keepdim', [False, True])
def test_sum_handler(sum_dims, keepdim):
-    world_size = 4
-    run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim)
if __name__ == '__main__':
......
@@ -7,7 +7,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class TensorConstructorModel(nn.Module):
@@ -22,6 +22,7 @@ class TensorConstructorModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_where_handler():
    model = TensorConstructorModel()
    tracer = ColoTracer(bias_addition_split=True)
......
@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class ReLuModel(nn.Module):
@@ -24,6 +24,7 @@ class ReLuModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_elementwise_handler():
    model = ReLuModel()
    tracer = ColoTracer(bias_addition_split=True)
......
-from functools import partial
import pytest
import torch
-import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +12,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -255,13 +251,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)])
@parameterize('model_cls', [ConvViewModel, LinearViewModel])
def test_view_handler(tgt_shape, model_cls):
-    world_size = 4
-    run_func = partial(check_view_handler,
-                       tgt_shape=tgt_shape,
-                       model_cls=model_cls,
-                       world_size=world_size,
-                       port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls)
if __name__ == '__main__':
......
@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
+from colossalai.testing import clear_cache_before_run
class ConvModel(nn.Module):
@@ -21,6 +22,7 @@ class ConvModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
+@clear_cache_before_run()
def test_where_handler():
    model = ConvModel()
    tracer = ColoTracer(bias_addition_split=True)
......
@@ -10,10 +10,11 @@ from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
def test_cost_graph():
    physical_mesh_id = torch.arange(0, 8)
    mesh_shape = (2, 4)
......