Unverified commit 80eba05b authored by Frank Lee, committed by GitHub

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
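The refactor replaces the per-test `functools.partial` + `torch.multiprocessing.spawn` boilerplate with a single `spawn` helper from `colossalai.testing`. Below is a minimal sketch of what such a helper plausibly does, inferred only from the call sites in this diff (workers keep the `(rank, world_size, port, ...)` signature and `free_port` is imported from `colossalai.testing`); the actual implementation may differ.

# Sketch only, inferred from the call sites in this diff; not the real implementation.
from functools import partial

import torch.multiprocessing as mp

from colossalai.testing import free_port


def spawn(func, nprocs=1, **kwargs):
    # Bind world_size, a freshly allocated port, and any test-specific kwargs,
    # then fork `nprocs` workers; mp.spawn passes the rank as the first argument.
    wrapped = partial(func, world_size=nprocs, port=free_port(), **kwargs)
    mp.spawn(wrapped, nprocs=nprocs)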
@@ -8,7 +8,7 @@ import colossalai
 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import free_port

 if AUTOCHUNK_AVAILABLE:
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
...
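This commit also moves `free_port` from `colossalai.utils` to `colossalai.testing`, a change repeated across several files below. For reference, a free-port helper is conventionally implemented by binding port 0 and letting the OS choose; a hedged sketch of that convention (not necessarily ColossalAI's code):

# Illustrative sketch of a conventional free_port; ColossalAI's version may differ.
import socket


def free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(('localhost', 0))    # port 0: the OS picks an unused port
        return sock.getsockname()[1]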
@@ -9,7 +9,7 @@ from colossalai.autochunk.utils import flat_list
 from colossalai.core import global_context as gpc
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import free_port

 if AUTOCHUNK_AVAILABLE:
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
...
-from functools import partial
 from typing import Dict, List, Tuple

 import pytest
 import torch
 import torch.fx
-import torch.multiprocessing as mp

 try:
     from fastfold.model.nn.evoformer import EvoformerBlock
@@ -15,6 +13,7 @@ except:
 from test_autochunk_alphafold_utils import run_test

 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn


 def get_model():
@@ -66,18 +65,19 @@ def get_chunk_target() -> Dict:
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
-@pytest.mark.parametrize("max_memory", [None, 20, 24])
-@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
+@clear_cache_before_run()
+@parameterize("max_memory", [None, 20, 24])
+@parameterize("data_args", [(32, 64)])
 def test_evoformer_block(data_args, max_memory):
-    run_func = partial(
+    spawn(
         run_test,
+        1,
         data_args=data_args,
         max_memory=max_memory,
         get_model=get_model,
         get_data=get_data,
         get_chunk_target=get_chunk_target,
     )
-    mp.spawn(run_func, nprocs=1)


 if __name__ == "__main__":
...
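Alongside `spawn`, these tests swap `@pytest.mark.parametrize` for ColossalAI's own `@parameterize`. Judging from the diff, `parameterize` feeds values to the function at call time rather than through pytest collection, so one collected test iterates all combinations itself. A plausible minimal sketch of that behavior (an assumption, not the library's exact code):

# Assumed behavior: re-run the wrapped function once per value of `name`.
from functools import wraps


def parameterize(name, values):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for value in values:
                kwargs[name] = value
                func(*args, **kwargs)
        return wrapper
    return decorator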
-from functools import partial
 from typing import List, Tuple

 import pytest
 import torch
 import torch.fx
-import torch.multiprocessing as mp

 try:
     from fastfold.model.nn.evoformer import EvoformerStack
@@ -15,6 +13,7 @@ except:
 from test_autochunk_alphafold_utils import run_test

 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn


 def get_model():
@@ -61,17 +60,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
-@pytest.mark.parametrize("max_memory", [None, 20, 24])
-@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
+@clear_cache_before_run()
+@parameterize("max_memory", [None, 20, 24])
+@parameterize("data_args", [(32, 64)])    # (msa_len, pair_len)
 def test_evoformer_stack(data_args, max_memory):
-    run_func = partial(
+    spawn(
         run_test,
+        1,
         data_args=data_args,
         max_memory=max_memory,
         get_model=get_model,
         get_data=get_data,
     )
-    mp.spawn(run_func, nprocs=1)


 if __name__ == "__main__":
...
-from functools import partial
 from typing import Dict, List, Tuple

 import pytest
 import torch
 import torch.fx
-import torch.multiprocessing as mp

 try:
     from fastfold.model.nn.evoformer import ExtraMSABlock
@@ -14,6 +12,7 @@ except:
 from test_autochunk_alphafold_utils import run_test

 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn


 def get_model():
@@ -57,17 +56,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
-@pytest.mark.parametrize("max_memory", [None, 20, 24])
-@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
+@clear_cache_before_run()
+@parameterize("max_memory", [None, 20, 24])
+@parameterize("data_args", [(32, 64)])    # (msa_len, pair_len)
 def test_extramsa_block(data_args, max_memory):
-    run_func = partial(
+    spawn(
         run_test,
+        1,
         data_args=data_args,
         max_memory=max_memory,
         get_model=get_model,
         get_data=get_data,
     )
-    mp.spawn(run_func, nprocs=1)


 if __name__ == "__main__":
...
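`@clear_cache_before_run()` is stacked above the parameterize decorators throughout these tests. The name suggests it resets cached state, most plausibly the CUDA allocator cache, so each parameterized run starts from a clean slate; a hedged sketch under that assumption:

# Assumption: clears cached state before the test body runs; the real
# decorator in colossalai.testing may clear more than the CUDA cache.
from functools import wraps

import torch


def clear_cache_before_run():
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return func(*args, **kwargs)
        return wrapper
    return decorator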
@@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 from colossalai.core import global_context as gpc
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import free_port

 if AUTOCHUNK_AVAILABLE:
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
...
-from functools import partial
 from typing import List, Tuple

 import pytest
 import torch
-import torch.multiprocessing as mp

 try:
     from diffusers import UNet2DModel
@@ -16,6 +14,7 @@ except:
 from test_autochunk_diffuser_utils import run_test

 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn

 BATCH_SIZE = 1
 HEIGHT = 448
@@ -37,17 +36,18 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
-@pytest.mark.parametrize("max_memory", [None, 150, 300])
+@clear_cache_before_run()
+@parameterize("model", MODELS)
+@parameterize("shape", [LATENTS_SHAPE])
+@parameterize("max_memory", [None, 150, 300])
 def test_evoformer_block(model, shape, max_memory):
-    run_func = partial(
+    spawn(
         run_test,
+        1,
         max_memory=max_memory,
         model=model,
         data=get_data(shape),
     )
-    mp.spawn(run_func, nprocs=1)


 if __name__ == "__main__":
...
-from functools import partial
 from typing import List, Tuple

 import pytest
 import torch
-import torch.multiprocessing as mp

 try:
     from transformers import GPT2Config, GPT2Model
@@ -16,6 +14,7 @@ except:
 from test_autochunk_transformer_utils import run_test

 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn

 BATCH_SIZE = 1
 SEQ_LENGTH = 512
@@ -35,18 +34,19 @@ def get_data(shape: tuple) -> Tuple[List, List]:
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
-@pytest.mark.parametrize("max_memory", [None, 6, 8])
+@clear_cache_before_run()
+@parameterize("model", MODELS)
+@parameterize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
+@parameterize("max_memory", [None, 6, 8])
 def test_autochunk_gpt(model, shape, max_memory):
-    run_func = partial(
+    spawn(
         run_test,
+        1,
         data=get_data(shape),
         max_memory=max_memory,
         model=model,
         config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
     )
-    mp.spawn(run_func, nprocs=1)


 if __name__ == "__main__":
...
@@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 from colossalai.core import global_context as gpc
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import free_port

 if AUTOCHUNK_AVAILABLE:
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
...
-from functools import partial
 from typing import List, Tuple

 import pytest
 import torch
-import torch.multiprocessing as mp

 try:
     from timm.models.vision_transformer import vit_large_patch16_384 as vit
@@ -16,6 +14,7 @@ except:
 from test_autochunk_vit_utils import run_test

 from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.testing import clear_cache_before_run, parameterize, spawn


 def get_data() -> Tuple[List, List]:
@@ -28,16 +27,17 @@ def get_data() -> Tuple[List, List]:
     not (AUTOCHUNK_AVAILABLE and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("max_memory", [None, 32, 40])
+@clear_cache_before_run()
+@parameterize("model", MODELS)
+@parameterize("max_memory", [None, 32, 40])
 def test_evoformer_block(model, max_memory):
-    run_func = partial(
+    spawn(
         run_test,
+        1,
         max_memory=max_memory,
         model=model,
         data=get_data(),
     )
-    mp.spawn(run_func, nprocs=1)


 if __name__ == "__main__":
...
@@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
 from colossalai.core import global_context as gpc
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import free_port

 if AUTOCHUNK_AVAILABLE:
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
...
-from functools import partial
-
-import torch.multiprocessing as mp
 import torch.nn as nn

 from colossalai.booster.accelerator import Accelerator
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.testing import clear_cache_before_run, parameterize


+@clear_cache_before_run()
 @parameterize('device', ['cpu', 'cuda'])
-def run_accelerator(device):
+def test_accelerator(device):
     acceleartor = Accelerator(device)
     model = nn.Linear(8, 8)
     model = acceleartor.configure_model(model)
     assert next(model.parameters()).device.type == device
     del model, acceleartor
-
-
-def run_dist(rank):
-    run_accelerator()
-
-
-@rerun_if_address_is_in_use()
-def test_accelerator():
-    world_size = 1
-    run_func = partial(run_dist)
-    mp.spawn(run_func, nprocs=world_size)
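Note the shape of this particular refactor: the accelerator test never needed multiple processes, so the `run_dist`/`mp.spawn` indirection is dropped entirely and `@parameterize` drives the test body directly. Called as a plain function (names as in the diff, call-order behavior assumed from the sketch above):

# With @parameterize('device', ['cpu', 'cuda']), one call runs the body per value.
if __name__ == '__main__':
    test_accelerator()    # executes with device='cpu', then device='cuda'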
-from functools import partial
-
 import torch
-import torch.multiprocessing as mp
 from torch.optim import Adam

 import colossalai
 from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -41,6 +37,4 @@ def run_torch_amp(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_torch_ddp_plugin():
-    world_size = 1
-    run_func = partial(run_torch_amp, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_torch_amp, 1)
-from functools import partial
-
-import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp

 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import GeminiPlugin
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.colo_parameter import ColoParameter
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -119,9 +114,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
 @rerun_if_address_is_in_use()
 def test_gemini_plugin(early_stop: bool = True):
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), early_stop=early_stop)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 2, early_stop=early_stop)

 if __name__ == '__main__':
...
-from functools import partial
-
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
@@ -10,8 +7,7 @@ import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import TorchDDPPlugin
 from colossalai.interface import OptimizerWrapper
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -103,6 +99,4 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_torch_ddp_plugin():
-    world_size = 2
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 2)
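Across all the distributed tests the contract is the same: the worker keeps the `(rank, world_size, port, ...)` signature and `spawn` fills in `world_size` and a free `port`. A hedged end-to-end sketch of that pairing (the `launch` keyword names, host, and backend values here are illustrative assumptions based on the imports shown in this diff):

# Sketch: a distributed worker paired with spawn after this refactor.
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # Assumed launch signature; host/backend values are illustrative.
    colossalai.launch(config={}, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')
    # ... actual test body goes here ...


@rerun_if_address_is_in_use()
def test_example():
    spawn(run_dist, 2)    # spawn supplies world_size=2 and a free port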
@@ -6,6 +6,7 @@ from torch.optim import Adam
 from torchvision.models import resnet18

 from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.testing import clear_cache_before_run, parameterize

 # ========
 # Note:
@@ -15,7 +16,8 @@ from colossalai.checkpoint_io import GeneralCheckpointIO
 # ========
-@pytest.mark.parametrize('use_safetensors', [True, False])
+@clear_cache_before_run()
+@parameterize('use_safetensors', [True, False])
 def test_unsharded_checkpoint(use_safetensors: bool):
     # create a model and optimizer
     model = resnet18()
...
-from functools import partial
-
 import torch
-import torch.multiprocessing as mp

 from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.utils import free_port
+from colossalai.testing import spawn


 def check_device_mesh_manager(rank, world_size, port):
@@ -31,9 +26,7 @@ def check_device_mesh_manager(rank, world_size, port):
 def test_device_mesh_manager():
-    world_size = 4
-    run_func = partial(check_device_mesh_manager, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_device_mesh_manager, 4)

 if __name__ == '__main__':
...
-from functools import partial
-from typing import List
-
 import pytest
 import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp

-from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group
+from colossalai.communication.p2p_v2 import _recv_object, _send_object
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
-from colossalai.utils import free_port, get_current_device
-from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.logging import disable_existing_loggers
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 disable_existing_loggers()
 world_size = 4
@@ -45,9 +40,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_object_list_p2p():
-    disable_existing_loggers()
-    run_func = partial(check_layer, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_layer, world_size)

 if __name__ == '__main__':
...
-from functools import partial
-
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp

 from colossalai.communication import all_gather, all_reduce, reduce_scatter
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
-from colossalai.utils import free_port, get_current_device
-from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device

 CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
@@ -66,9 +64,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_comm():
-    world_size = 4
-    run_func = partial(check_layer, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_layer, 4)

 if __name__ == '__main__':
...
-from functools import partial
-
 import pytest
 import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp

-from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward
+from colossalai.communication.p2p import (
+    recv_backward,
+    recv_forward,
+    send_backward,
+    send_backward_recv_forward,
+    send_forward,
+    send_forward_recv_backward,
+)
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
-from colossalai.utils import free_port, get_current_device
-from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 CONFIG = dict(parallel=dict(pipeline=2))

 torch.manual_seed(123)
@@ -96,9 +99,7 @@ def check_layer(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_object_list_p2p():
-    world_size = 2
-    run_func = partial(check_layer, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_layer, 2)

 if __name__ == '__main__':
...
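Finally, `@rerun_if_address_is_in_use()` stays on every multiprocess test. Its name implies it retries a test whose process-group setup hit an "address already in use" error from a lingering socket; a rough sketch of that idea (an assumption, not the library's implementation):

# Assumed retry-on-EADDRINUSE behavior; the real decorator likely matches
# the error message pattern and bounds the number of reruns.
from functools import wraps


def rerun_if_address_is_in_use(max_retries=3):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except RuntimeError as err:
                    if 'address already in use' not in str(err).lower():
                        raise
                    if attempt == max_retries - 1:
                        raise
        return wrapper
    return decorator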