Unverified Commit 80eba05b authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
import torch
import torch.nn as nn
import transformers
from torch.fx import GraphModule
......@@ -7,10 +6,10 @@ from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import parameterize
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
......@@ -20,6 +19,7 @@ HIDDEN_DIM = 384
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)
......
......@@ -6,6 +6,8 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import clear_cache_before_run
class LinearModel(nn.Module):
......@@ -26,6 +28,7 @@ class LinearModel(nn.Module):
@pytest.mark.skip('meta tensor has some bugs in 1.11')
@clear_cache_before_run()
def test_liveness_analysis():
model = LinearModel()
tracer = ColoTracer(bias_addition_split=True)
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.meta_profiler import meta_register
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
@parameterize('func', [
torch.nn.functional.softmax,
torch.nn.functional.relu,
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
......@@ -10,8 +7,7 @@ from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
......@@ -62,9 +58,7 @@ def _binary_elementwise_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_binary_elementwise_mem_test, 4)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
......@@ -25,7 +20,7 @@ class ConvFunctionModule(nn.Module):
return nn.functional.conv2d(input, self.conv_weight)
def _conv_module_mem_test(rank, bias, world_size, port):
def _conv_module_mem_test(rank, world_size, port, bias):
"""This function is for conv memory test
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
......@@ -62,9 +57,7 @@ def _conv_module_mem_test(rank, bias, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_conv_meta_concrete_info_match(bias=False):
world_size = 4
run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_conv_module_mem_test, 4, bias=bias)
def _conv_function_mem_test(rank, world_size, port):
......@@ -103,9 +96,7 @@ def _conv_function_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_conv_function_concrete_info_match():
world_size = 4
run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_conv_function_mem_test, 4)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
from colossalai.testing.utils import clear_cache_before_run
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import meta_register
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
def test_embedding_meta_info():
meta_func = meta_register.get(torch.nn.Embedding)
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
class MyModule(nn.Module):
......@@ -63,9 +53,7 @@ def _linear_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_module_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_linear_module_mem_test, 4)
def _linear_function_mem_test(rank, world_size, port):
......@@ -101,9 +89,7 @@ def _linear_function_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_function_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_linear_function_mem_test, 4)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
......@@ -28,6 +10,7 @@ if torch.__version__ >= '1.12.0':
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
@parameterize(
'tensor_shapes',
[
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import meta_register
def _batchnorm_module_mem_test(rank, world_size, port):
......@@ -62,9 +50,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_batchnorm_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_batchnorm_module_mem_test, 4)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations')
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
......@@ -51,9 +46,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_adaptiveavgpool_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_adaptiveavgpool_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_adaptiveavgpool_module_mem_test, 4)
def _maxpool_module_mem_test(rank, world_size, port):
......@@ -92,9 +85,7 @@ def _maxpool_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_maxpool_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_maxpool_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_maxpool_module_mem_test, 4)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
from colossalai.testing.utils import clear_cache_before_run
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
......@@ -37,6 +20,7 @@ class SplitModule(nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
def test_tensor_meta_info():
"""test tensor related meta information
We will just use torch.Tensor.split for the test
......
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
from colossalai.testing.utils import clear_cache_before_run
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
......@@ -26,6 +10,7 @@ if torch.__version__ >= '1.12.0':
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
def test_where_meta_info():
meta_func = meta_register.get(torch.where)
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
......@@ -11,9 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
......@@ -45,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module):
return output
def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = module(using_kwargs).cuda()
......@@ -249,14 +244,13 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por
@parameterize('using_kwargs', [True, False])
@rerun_if_address_is_in_use()
def test_2d_device_mesh(module, bias_shape, using_kwargs):
world_size = 4
run_func = partial(check_2d_device_mesh,
module=module,
bias_shape=bias_shape,
world_size=world_size,
using_kwargs=using_kwargs,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(
check_2d_device_mesh,
4,
module=module,
bias_shape=bias_shape,
using_kwargs=using_kwargs,
)
@pytest.mark.skip("skip due to bias cases not ready")
......@@ -267,14 +261,13 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs):
@parameterize('using_kwargs', [True, False])
@rerun_if_address_is_in_use()
def test_1d_device_mesh(module, bias_shape, using_kwargs):
world_size = 4
run_func = partial(check_1d_device_mesh,
module=module,
bias_shape=bias_shape,
using_kwargs=using_kwargs,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(
check_1d_device_mesh,
4,
module=module,
bias_shape=bias_shape,
using_kwargs=using_kwargs,
)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
......@@ -17,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
......@@ -45,7 +40,7 @@ class AddmmModel_with_param(nn.Module):
return x
def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if model_cls == AddmmModel:
......@@ -189,13 +184,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
@parameterize('model_cls', [AddmmModel, AddmmModel_with_param])
@rerun_if_address_is_in_use()
def test_addmm_handler(input_shape, model_cls):
world_size = 4
run_func_function = partial(check_addmm_function_handler,
input_shape=input_shape,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func_function, nprocs=world_size)
spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
......@@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
......@@ -114,9 +109,7 @@ def check_bn_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bn_module_handler():
world_size = 4
run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_bn_module_handler, 4)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from colossalai._analyzer.fx.graph_module import ColoGraphModule
......@@ -19,9 +15,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
WEIGHT_SHAPE = (32, 16)
......@@ -168,9 +162,7 @@ def check_linear_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler():
world_size = 4
run_func_module = partial(check_linear_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(check_linear_module_handler)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
......@@ -18,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
......@@ -35,7 +29,7 @@ class LinearModule(torch.nn.Module):
return x
def check_linear_module_handler(rank, bias, world_size, port):
def check_linear_module_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModule(16, 32, bias=bias).cuda()
......@@ -157,9 +151,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(bias=True):
world_size = 4
run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(check_linear_module_handler, bias=bias)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
......@@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
......@@ -149,7 +144,7 @@ class BEOpModelWithIntConst(nn.Module):
return out
def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
......@@ -236,13 +231,12 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_handler_with_tensor(op, other_dim):
world_size = 4
run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
op=op,
other_dim=other_dim,
world_size=world_size,
port=free_port())
mp.spawn(run_func_tensor, nprocs=world_size)
spawn(
check_binary_elementwise_handler_with_tensor,
4,
op=op,
other_dim=other_dim,
)
@run_on_environment_flag(name='AUTO_PARALLEL')
......@@ -252,14 +246,13 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
world_size = 4
run_func_int = partial(check_binary_elementwise_handler_with_int,
op=op,
model_cls=model_cls,
other_dim=other_dim,
world_size=world_size,
port=free_port())
mp.spawn(run_func_int, nprocs=world_size)
spawn(
check_binary_elementwise_handler_with_int,
4,
op=op,
model_cls=model_cls,
other_dim=other_dim,
)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
......@@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
......@@ -207,11 +202,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bmm_handler(module):
world_size = 4
run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port())
mp.spawn(run_func_2d, nprocs=world_size)
run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port())
mp.spawn(run_func_1d, nprocs=world_size)
spawn(check_2d_device_mesh, 4, module=module)
spawn(check_1d_device_mesh, 4, module=module)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
......@@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
def check_conv_module_handler(rank, bias, world_size, port):
def check_conv_module_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda()
......@@ -155,7 +150,7 @@ class ConvModel(nn.Module):
return x
def check_conv_function_handler(rank, bias, world_size, port):
def check_conv_function_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = ConvModel().cuda()
......@@ -302,9 +297,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_module_handler(bias=False):
world_size = 4
run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_conv_module_handler, 4, bias=bias)
@run_on_environment_flag(name='AUTO_PARALLEL')
......@@ -314,9 +307,7 @@ def test_conv_module_handler(bias=False):
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_function_handler(bias=False):
world_size = 4
run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_conv_function_handler, 4, bias=bias)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment