Commit 80eba05b (unverified), authored by Frank Lee, committed by GitHub
Browse files

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
...@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHan ...@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHan
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class ReshapeModel(nn.Module): class ReshapeModel(nn.Module):
...@@ -23,6 +23,7 @@ class ReshapeModel(nn.Module): ...@@ -23,6 +23,7 @@ class ReshapeModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_reshape_handler(): def test_reshape_handler():
model = ReshapeModel() model = ReshapeModel()
tracer = ColoTracer(bias_addition_split=True) tracer = ColoTracer(bias_addition_split=True)
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
...@@ -16,9 +13,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat ...@@ -16,9 +13,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
NUM_EMBEDDINGS = 16 NUM_EMBEDDINGS = 16
...@@ -272,18 +268,14 @@ def check_embedding_function_handler(rank, world_size, port): ...@@ -272,18 +268,14 @@ def check_embedding_function_handler(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_embedding_module_handler(): def test_embedding_module_handler():
world_size = 4 spawn(check_embedding_module_handler, 4)
run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_embedding_function_handler(): def test_embedding_function_handler():
world_size = 4 spawn(check_embedding_function_handler, 4)
run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer ...@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import clear_cache_before_run
class GetattrModel(nn.Module): class GetattrModel(nn.Module):
...@@ -22,6 +23,7 @@ class GetattrModel(nn.Module): ...@@ -22,6 +23,7 @@ class GetattrModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_getattr_handler(): def test_getattr_handler():
model = GetattrModel() model = GetattrModel()
tracer = ColoTracer(bias_addition_split=True) tracer = ColoTracer(bias_addition_split=True)
......
...@@ -2,7 +2,6 @@ from functools import partial ...@@ -2,7 +2,6 @@ from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
...@@ -14,12 +13,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import Li ...@@ -14,12 +13,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import Li
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
...@@ -103,12 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): ...@@ -103,12 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))]) # @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) @parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))])
def test_getitem_from_tensor_handler(getitem_index): def test_getitem_from_tensor_handler(getitem_index):
world_size = 4 spawn(check_getitem_from_tensor_handler, 4)
run_func = partial(check_getitem_from_tensor_handler,
getitem_index=getitem_index,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
class GetItemFromTupleModel(nn.Module): class GetItemFromTupleModel(nn.Module):
...@@ -123,6 +115,7 @@ class GetItemFromTupleModel(nn.Module): ...@@ -123,6 +115,7 @@ class GetItemFromTupleModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_getitem_from_tuple_handler(): def test_getitem_from_tuple_handler():
model = GetItemFromTupleModel() model = GetItemFromTupleModel()
tracer = ColoTracer() tracer = ColoTracer()
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
...@@ -11,12 +8,10 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer ...@@ -11,12 +8,10 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
...@@ -104,9 +99,7 @@ def check_ln_module_handler(rank, world_size, port): ...@@ -104,9 +99,7 @@ def check_ln_module_handler(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_ln_module_handler(): def test_ln_module_handler():
world_size = 4 spawn(check_ln_module_handler, 4)
run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
...@@ -18,14 +15,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ...@@ -18,14 +15,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize from colossalai.testing.utils import parameterize
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
def check_linear_module_handler(rank, bias, input_shape, world_size, port): def check_linear_module_handler(rank, world_size, port, bias, input_shape):
disable_existing_loggers() disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda() model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
...@@ -172,7 +168,7 @@ class LinearModel(nn.Module): ...@@ -172,7 +168,7 @@ class LinearModel(nn.Module):
return x return x
def check_linear_function_handler(rank, bias, input_shape, world_size, port): def check_linear_function_handler(rank, world_size, port, bias, input_shape):
disable_existing_loggers() disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModel().cuda() model = LinearModel().cuda()
...@@ -313,19 +309,18 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port): ...@@ -313,19 +309,18 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_linear_handler(input_shape, bias=False): def test_linear_handler(input_shape, bias=False):
world_size = 4 spawn(
run_func_module = partial(check_linear_module_handler, check_linear_module_handler,
bias=bias, 4,
input_shape=input_shape, bias=bias,
world_size=world_size, input_shape=input_shape,
port=free_port()) )
mp.spawn(run_func_module, nprocs=world_size) spawn(
run_func_function = partial(check_linear_function_handler, check_linear_function_handler,
bias=bias, 4,
input_shape=input_shape, bias=bias,
world_size=world_size, input_shape=input_shape,
port=free_port()) )
mp.spawn(run_func_function, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -18,7 +18,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ...@@ -18,7 +18,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector, StrategiesVector,
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.utils import parameterize from colossalai.testing.utils import clear_cache_before_run, parameterize
class MatMulModule(nn.Module): class MatMulModule(nn.Module):
...@@ -28,6 +28,7 @@ class MatMulModule(nn.Module): ...@@ -28,6 +28,7 @@ class MatMulModule(nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
@parameterize( @parameterize(
'tensor_shapes', 'tensor_shapes',
[ [
......
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -8,11 +7,11 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer ...@@ -8,11 +7,11 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.testing import clear_cache_before_run, run_on_environment_flag
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_norm_pool_handler(): def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer(bias_addition_split=True) tracer = ColoTracer(bias_addition_split=True)
......
...@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer ...@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import clear_cache_before_run, parameterize
class OutputModel(nn.Module): class OutputModel(nn.Module):
...@@ -23,7 +23,7 @@ class OutputModel(nn.Module): ...@@ -23,7 +23,7 @@ class OutputModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('output_option', ['distributed', 'replicated']) @parameterize('output_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use() @clear_cache_before_run()
def test_output_handler(output_option): def test_output_handler(output_option):
model = OutputModel() model = OutputModel()
tracer = ColoTracer(bias_addition_split=True) tracer = ColoTracer(bias_addition_split=True)
......
...@@ -2,7 +2,6 @@ from functools import partial ...@@ -2,7 +2,6 @@ from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
...@@ -15,9 +14,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat ...@@ -15,9 +14,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
...@@ -55,7 +53,7 @@ class LinearReshapeModel(nn.Module): ...@@ -55,7 +53,7 @@ class LinearReshapeModel(nn.Module):
return permute_node return permute_node
def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port): def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls):
disable_existing_loggers() disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if call_function == torch.permute: if call_function == torch.permute:
...@@ -328,14 +326,13 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, ...@@ -328,14 +326,13 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) @parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))])
@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel]) @parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel])
def test_view_handler(call_function, reshape_dims, model_cls): def test_view_handler(call_function, reshape_dims, model_cls):
world_size = 4 spawn(
run_func = partial(check_view_handler, check_view_handler,
call_function=call_function, 4,
reshape_dims=reshape_dims, call_function=call_function,
model_cls=model_cls, reshape_dims=reshape_dims,
world_size=world_size, model_cls=model_cls,
port=free_port()) )
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer ...@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import clear_cache_before_run, parameterize
class PlaceholderModel(nn.Module): class PlaceholderModel(nn.Module):
...@@ -22,7 +22,7 @@ class PlaceholderModel(nn.Module): ...@@ -22,7 +22,7 @@ class PlaceholderModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('placeholder_option', ['distributed', 'replicated']) @parameterize('placeholder_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use() @clear_cache_before_run()
def test_placeholder_handler(placeholder_option): def test_placeholder_handler(placeholder_option):
model = PlaceholderModel() model = PlaceholderModel()
tracer = ColoTracer(bias_addition_split=True) tracer = ColoTracer(bias_addition_split=True)
......
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
...@@ -9,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHan ...@@ -9,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHan
from colossalai.auto_parallel.tensor_shard.options import ShardOption from colossalai.auto_parallel.tensor_shard.options import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class LinearModel(nn.Module): class LinearModel(nn.Module):
...@@ -108,6 +107,7 @@ def check_shard_option(shard_option): ...@@ -108,6 +107,7 @@ def check_shard_option(shard_option):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_shard_option(): def test_shard_option():
# for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]: # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
for shard_option in [ShardOption.SHARD_LAST_AXIS]: for shard_option in [ShardOption.SHARD_LAST_AXIS]:
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat ...@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
...@@ -33,7 +28,7 @@ class LinearSplitModel(nn.Module): ...@@ -33,7 +28,7 @@ class LinearSplitModel(nn.Module):
return softmax_node return softmax_node
def check_split_handler(rank, softmax_dim, model_cls, world_size, port): def check_split_handler(rank, world_size, port, softmax_dim, model_cls):
disable_existing_loggers() disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = model_cls(softmax_dim=softmax_dim).cuda() model = model_cls(softmax_dim=softmax_dim).cuda()
...@@ -176,13 +171,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port): ...@@ -176,13 +171,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
@parameterize('softmax_dim', [0, 1, 2, 3]) @parameterize('softmax_dim', [0, 1, 2, 3])
@parameterize('model_cls', [LinearSplitModel]) @parameterize('model_cls', [LinearSplitModel])
def test_split_handler(softmax_dim, model_cls): def test_split_handler(softmax_dim, model_cls):
world_size = 4 spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls)
run_func = partial(check_split_handler,
softmax_dim=softmax_dim,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
...@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat ...@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
...@@ -47,7 +42,7 @@ class LinearSplitModel(nn.Module): ...@@ -47,7 +42,7 @@ class LinearSplitModel(nn.Module):
return split_node return split_node
def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port): def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls):
disable_existing_loggers() disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = model_cls(split_size=split_size, split_dim=split_dim).cuda() model = model_cls(split_size=split_size, split_dim=split_dim).cuda()
...@@ -258,14 +253,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port ...@@ -258,14 +253,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
@parameterize('split_dim', [0, 1, 2]) @parameterize('split_dim', [0, 1, 2])
@parameterize('model_cls', [ConvSplitModel, LinearSplitModel]) @parameterize('model_cls', [ConvSplitModel, LinearSplitModel])
def test_split_handler(split_size, split_dim, model_cls): def test_split_handler(split_size, split_dim, model_cls):
world_size = 4 spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls)
run_func = partial(check_split_handler,
split_size=split_size,
split_dim=split_dim,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
...@@ -14,9 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat ...@@ -14,9 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
...@@ -36,7 +31,7 @@ class LinearSumModel(nn.Module): ...@@ -36,7 +31,7 @@ class LinearSumModel(nn.Module):
return sum_node return sum_node
def check_sum_handler(rank, sum_dims, keepdim, world_size, port): def check_sum_handler(rank, world_size, port, sum_dims, keepdim):
disable_existing_loggers() disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda() model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda()
...@@ -228,9 +223,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port): ...@@ -228,9 +223,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
@parameterize('sum_dims', [(0, 2), 1]) @parameterize('sum_dims', [(0, 2), 1])
@parameterize('keepdim', [False, True]) @parameterize('keepdim', [False, True])
def test_sum_handler(sum_dims, keepdim): def test_sum_handler(sum_dims, keepdim):
world_size = 4 spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim)
run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -7,7 +7,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer ...@@ -7,7 +7,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class TensorConstructorModel(nn.Module): class TensorConstructorModel(nn.Module):
...@@ -22,6 +22,7 @@ class TensorConstructorModel(nn.Module): ...@@ -22,6 +22,7 @@ class TensorConstructorModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_where_handler(): def test_where_handler():
model = TensorConstructorModel() model = TensorConstructorModel()
tracer = ColoTracer(bias_addition_split=True) tracer = ColoTracer(bias_addition_split=True)
......
...@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv ...@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class ReLuModel(nn.Module): class ReLuModel(nn.Module):
...@@ -24,6 +24,7 @@ class ReLuModel(nn.Module): ...@@ -24,6 +24,7 @@ class ReLuModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_elementwise_handler(): def test_elementwise_handler():
model = ReLuModel() model = ReLuModel()
tracer = ColoTracer(bias_addition_split=True) tracer = ColoTracer(bias_addition_split=True)
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.graph_module import ColoGraphModule
...@@ -15,9 +12,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat ...@@ -15,9 +12,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
...@@ -255,13 +251,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): ...@@ -255,13 +251,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) @parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)])
@parameterize('model_cls', [ConvViewModel, LinearViewModel]) @parameterize('model_cls', [ConvViewModel, LinearViewModel])
def test_view_handler(tgt_shape, model_cls): def test_view_handler(tgt_shape, model_cls):
world_size = 4 spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls)
run_func = partial(check_view_handler,
tgt_shape=tgt_shape,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer ...@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import clear_cache_before_run
class ConvModel(nn.Module): class ConvModel(nn.Module):
...@@ -21,6 +22,7 @@ class ConvModel(nn.Module): ...@@ -21,6 +22,7 @@ class ConvModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_where_handler(): def test_where_handler():
model = ConvModel() model = ConvModel()
tracer = ColoTracer(bias_addition_split=True) tracer = ColoTracer(bias_addition_split=True)
......
...@@ -10,10 +10,11 @@ from colossalai.auto_parallel.tensor_shard.options import SolverOptions ...@@ -10,10 +10,11 @@ from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing import clear_cache_before_run, run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_cost_graph(): def test_cost_graph():
physical_mesh_id = torch.arange(0, 8) physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4) mesh_shape = (2, 4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment