"docs/source/vscode:/vscode.git/clone" did not exist on "281b33f3628feb6bc7cb26faa42d260a169cb386"
Unverified commit 80eba05b, authored by Frank Lee, committed by GitHub
Browse files

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
......@@ -8,7 +8,7 @@ import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
......
......@@ -9,7 +9,7 @@ from colossalai.autochunk.utils import flat_list
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
......
from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerBlock
......@@ -15,6 +13,7 @@ except:
from test_autochunk_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_model():
......@@ -66,18 +65,19 @@ def get_chunk_target() -> Dict:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@clear_cache_before_run()
@parameterize("max_memory", [None, 20, 24])
@parameterize("data_args", [(32, 64)])  # (msa_len, pair_len)
def test_evoformer_block(data_args, max_memory):
    """Run the EvoformerBlock autochunk test in a single spawned worker.

    Args:
        data_args: (msa_len, pair_len) tuple forwarded to run_test.
        max_memory: memory budget (MB) for autochunk search; None = unlimited.
    """
    # spawn(fn, nprocs, **kwargs) replaces the old partial + mp.spawn pattern;
    # extra keyword arguments are forwarded to run_test in each worker.
    spawn(
        run_test,
        1,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,
    )
if __name__ == "__main__":
......
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import EvoformerStack
......@@ -15,6 +13,7 @@ except:
from test_autochunk_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_model():
......@@ -61,17 +60,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@clear_cache_before_run()
@parameterize("max_memory", [None, 20, 24])
@parameterize("data_args", [(32, 64)])  # (msa_len, pair_len)
def test_evoformer_stack(data_args, max_memory):
    """Run the EvoformerStack autochunk test in a single spawned worker.

    Args:
        data_args: (msa_len, pair_len) tuple forwarded to run_test.
        max_memory: memory budget (MB) for autochunk search; None = unlimited.
    """
    # Keyword arguments after nprocs are forwarded to run_test by spawn().
    spawn(
        run_test,
        1,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
    )
if __name__ == "__main__":
......
from functools import partial
from typing import Dict, List, Tuple
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
try:
from fastfold.model.nn.evoformer import ExtraMSABlock
......@@ -14,6 +12,7 @@ except:
from test_autochunk_alphafold_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_model():
......@@ -57,17 +56,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@clear_cache_before_run()
@parameterize("max_memory", [None, 20, 24])
@parameterize("data_args", [(32, 64)])  # (msa_len, pair_len)
def test_extramsa_block(data_args, max_memory):
    """Run the ExtraMSABlock autochunk test in a single spawned worker.

    Args:
        data_args: (msa_len, pair_len) tuple forwarded to run_test.
        max_memory: memory budget (MB) for autochunk search; None = unlimited.
    """
    # Keyword arguments after nprocs are forwarded to run_test by spawn().
    spawn(
        run_test,
        1,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
    )
if __name__ == "__main__":
......
......@@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
......
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from diffusers import UNet2DModel
......@@ -16,6 +14,7 @@ except:
from test_autochunk_diffuser_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
BATCH_SIZE = 1
HEIGHT = 448
......@@ -37,17 +36,18 @@ def get_data(shape: tuple) -> Tuple[List, List]:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@clear_cache_before_run()
@parameterize("model", MODELS)
@parameterize("shape", [LATENTS_SHAPE])
@parameterize("max_memory", [None, 150, 300])
def test_evoformer_block(model, shape, max_memory):
    """Run the diffuser UNet autochunk test in a single spawned worker.

    Args:
        model: model class under test (from MODELS).
        shape: latent tensor shape used to build the test data.
        max_memory: memory budget (MB) for autochunk search; None = unlimited.
    """
    # Data is materialized once here and forwarded to the worker via spawn().
    spawn(
        run_test,
        1,
        max_memory=max_memory,
        model=model,
        data=get_data(shape),
    )
if __name__ == "__main__":
......
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from transformers import GPT2Config, GPT2Model
......@@ -16,6 +14,7 @@ except:
from test_autochunk_transformer_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
BATCH_SIZE = 1
SEQ_LENGTH = 512
......@@ -35,18 +34,19 @@ def get_data(shape: tuple) -> Tuple[List, List]:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@clear_cache_before_run()
@parameterize("model", MODELS)
@parameterize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
@parameterize("max_memory", [None, 6, 8])
def test_autochunk_gpt(model, shape, max_memory):
    """Run the GPT2 autochunk test in a single spawned worker.

    Args:
        model: model class under test (from MODELS).
        shape: (batch, seq_len) shape used to build the test data.
        max_memory: memory budget (MB) for autochunk search; None = unlimited.
    """
    # A small GPT2 config keeps the test fast; forwarded to run_test via spawn().
    spawn(
        run_test,
        1,
        data=get_data(shape),
        max_memory=max_memory,
        model=model,
        config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
    )
if __name__ == "__main__":
......
......@@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
......
from functools import partial
from typing import List, Tuple
import pytest
import torch
import torch.multiprocessing as mp
try:
from timm.models.vision_transformer import vit_large_patch16_384 as vit
......@@ -16,6 +14,7 @@ except:
from test_autochunk_vit_utils import run_test
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.testing import clear_cache_before_run, parameterize, spawn
def get_data() -> Tuple[List, List]:
......@@ -28,16 +27,17 @@ def get_data() -> Tuple[List, List]:
not (AUTOCHUNK_AVAILABLE and HAS_REPO),
reason="torch version is lower than 1.12.0",
)
@clear_cache_before_run()
@parameterize("model", MODELS)
@parameterize("max_memory", [None, 32, 40])
def test_evoformer_block(model, max_memory):
    """Run the ViT autochunk test in a single spawned worker.

    Args:
        model: model class under test (from MODELS).
        max_memory: memory budget (MB) for autochunk search; None = unlimited.
    """
    # Keyword arguments after nprocs are forwarded to run_test by spawn().
    spawn(
        run_test,
        1,
        max_memory=max_memory,
        model=model,
        data=get_data(),
    )
if __name__ == "__main__":
......
......@@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import free_port
if AUTOCHUNK_AVAILABLE:
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
......
from functools import partial
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.booster.accelerator import Accelerator
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.testing import clear_cache_before_run, parameterize
@clear_cache_before_run()
@parameterize('device', ['cpu', 'cuda'])
def test_accelerator(device):
    """Check that Accelerator moves a model to the requested device.

    Args:
        device: target device string ('cpu' or 'cuda').
    """
    accelerator = Accelerator(device)
    model = nn.Linear(8, 8)
    model = accelerator.configure_model(model)
    # The configured model's parameters must live on the requested device.
    assert next(model.parameters()).device.type == device
    # Release references explicitly so GPU memory is freed between parameterize runs.
    del model, accelerator
from functools import partial
import torch
import torch.multiprocessing as mp
from torch.optim import Adam
import colossalai
from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
......@@ -41,6 +37,4 @@ def run_torch_amp(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_torch_ddp_plugin():
    """Run the torch AMP mixed-precision test in one spawned worker process."""
    # spawn() picks a free port and passes (rank, world_size, port) to the worker.
    spawn(run_torch_amp, 1)
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
......@@ -119,9 +114,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
    """Run the Gemini plugin test across 2 spawned worker processes.

    Args:
        early_stop: if True, the distributed worker stops at the first failure.
    """
    # spawn() handles world size and free-port selection; early_stop is forwarded.
    spawn(run_dist, 2, early_stop=early_stop)
if __name__ == '__main__':
......
from functools import partial
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
......@@ -10,8 +7,7 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
......@@ -103,6 +99,4 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_torch_ddp_plugin():
    """Run the TorchDDP plugin test across 2 spawned worker processes."""
    # spawn() picks a free port and passes (rank, world_size, port) to each worker.
    spawn(run_dist, 2)
......@@ -6,6 +6,7 @@ from torch.optim import Adam
from torchvision.models import resnet18
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.testing import clear_cache_before_run, parameterize
# ========
# Note:
......@@ -15,7 +16,8 @@ from colossalai.checkpoint_io import GeneralCheckpointIO
# ========
@pytest.mark.parametrize('use_safetensors', [True, False])
@clear_cache_before_run()
@parameterize('use_safetensors', [True, False])
def test_unsharded_checkpoint(use_safetensors: bool):
# create a model and optimizer
model = resnet18()
......
from functools import partial
import torch
import torch.multiprocessing as mp
from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import spawn
def check_device_mesh_manager(rank, world_size, port):
......@@ -31,9 +26,7 @@ def check_device_mesh_manager(rank, world_size, port):
def test_device_mesh_manager():
    """Run the DeviceMeshManager checks across 4 spawned worker processes."""
    # spawn() replaces the old partial + mp.spawn pattern and manages ports itself.
    spawn(check_device_mesh_manager, 4)
if __name__ == '__main__':
......
from functools import partial
from typing import List
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group
from colossalai.communication.p2p_v2 import _recv_object, _send_object
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
disable_existing_loggers()
world_size = 4
......@@ -45,9 +40,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_object_list_p2p():
    """Run the p2p_v2 send/recv object test across `world_size` workers."""
    # Silence logger noise before forking workers.
    disable_existing_loggers()
    # world_size is the module-level constant defined above.
    spawn(check_layer, world_size)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication import all_gather, all_reduce, reduce_scatter
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
......@@ -66,9 +64,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_comm():
    """Run all-gather / all-reduce / reduce-scatter checks across 8-GPU config on 4 workers."""
    # spawn() picks a free port and passes (rank, world_size, port) to each worker.
    spawn(check_layer, 4)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward
from colossalai.communication.p2p import (
recv_backward,
recv_forward,
send_backward,
send_backward_recv_forward,
send_forward,
send_forward_recv_backward,
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(pipeline=2))
torch.manual_seed(123)
......@@ -96,9 +99,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_object_list_p2p():
    """Run the pipeline p2p send/recv checks across 2 spawned worker processes."""
    # spawn() picks a free port and passes (rank, world_size, port) to each worker.
    spawn(check_layer, 2)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment