Unverified Commit 5a1a095b authored by Frank Lee, committed by GitHub

[test] refactored with the new rerun decorator (#763)

* [test] refactored with the new rerun decorator

* polish test case
parent deaf99f4
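The pattern applied throughout this commit: every distributed test that was decorated with rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") now uses the purpose-built rerun_if_address_is_in_use() decorator instead. A minimal sketch of the idea follows; it assumes the new decorator simply wraps rerun_on_exception with those defaults, which may differ from the actual implementation in colossalai.testing:

# Sketch only: one plausible way rerun_if_address_is_in_use could be built
# on top of the existing rerun_on_exception decorator.
import torch.multiprocessing as mp
from colossalai.testing import rerun_on_exception

def rerun_if_address_is_in_use():
    # Assumption: reuse the generic retry decorator with the
    # "Address already in use" match that each test used to spell out.
    return rerun_on_exception(exception_type=mp.ProcessRaisedException,
                              pattern=".*Address already in use.*")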
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
from colossalai.zero.sharded_param import ShardedTensor
import colossalai
@@ -35,7 +35,7 @@ def run_tensor_move(rank):
    assert (tgt_t.device.type == 'cpu')
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_tensor_move():
    mp.spawn(run_tensor_move, nprocs=1)
...
@@ -3,6 +3,7 @@ from functools import partial
from pathlib import Path
import colossalai
+from colossalai.testing.utils import rerun_if_address_is_in_use
import pytest
import torch
import torch.multiprocessing as mp
@@ -10,7 +11,7 @@ import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_dataloader
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
@@ -87,7 +88,7 @@ def run_no_pipeline(rank, world_size, port):
@pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_engine():
    world_size = 4
    func = partial(run_no_pipeline, world_size=world_size, port=free_port())
...
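For reference, the overall shape of a test touched by this refactor looks roughly like the sketch below. The worker function run_worker and its body are placeholders standing in for the real per-rank test logic; only the decorator stack, the free_port/partial wiring, and the mp.spawn call mirror the diff above:

# Sketch of the decorated distributed-test pattern used throughout this commit.
from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port

def run_worker(rank, world_size, port):
    # Placeholder: a real test would initialize the distributed context here
    # (e.g. via colossalai.launch) and exercise the feature under test.
    assert 0 <= rank < world_size

@pytest.mark.dist
@rerun_if_address_is_in_use()   # rerun the test if the chosen port is already taken
def test_example():
    world_size = 4
    func = partial(run_worker, world_size=world_size, port=free_port())
    mp.spawn(func, nprocs=world_size)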
@@ -16,7 +16,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from functools import partial
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
def checkpoint_wrapper(module, enable=True):
@@ -102,7 +102,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_zero_clip_grad():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
...
@@ -6,7 +6,7 @@ import pytest
import torch
import torch.multiprocessing as mp
from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy
@@ -62,7 +62,7 @@ def _run_dist(rank, world_size, port):
# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_found_inf(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
...
@@ -8,7 +8,7 @@ import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import get_dist_logger
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
@@ -64,7 +64,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_zero_init_context(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
...
@@ -10,7 +10,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
from functools import partial
@@ -64,7 +64,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_mem_collector(world_size=2):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
...
@@ -7,7 +7,7 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -59,7 +59,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_shard_model_v2(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
...
@@ -5,12 +5,11 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from colossalai.testing import rerun_on_exception
from tests.test_zero.common import CONFIG, allclose
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
@@ -37,7 +36,7 @@ def _run_shard_tensor(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_shard_tensor(world_size):
    run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
@@ -85,7 +84,7 @@ def _run_shard_param_v2(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_shard_param_v2(world_size):
    run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
...
@@ -8,7 +8,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -105,7 +105,7 @@ def _run_dist(rank, world_size, port):
# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_sharded_optim_v2(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
...
@@ -10,7 +10,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
@@ -71,7 +71,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_sharded_optim_with_sync_bn():
    """
    This test is to make sure that buffers are synchronized between ranks
...
@@ -8,7 +8,7 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
-from colossalai.testing import parameterize, rerun_on_exception
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -49,7 +49,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_zero_state_dict(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
...
@@ -10,7 +10,7 @@ from colossalai.gemini import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.utils import free_port
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
from torch.nn.parameter import Parameter
from typing import List
from functools import partial
@@ -120,8 +120,8 @@ def run_dist(rank, world_size, port):
    run_stm()
-@pytest.mark.gpu
+@pytest.mark.dist
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_stateful_tensor_manager(world_size=1):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
...
@@ -6,6 +6,7 @@ from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage,
        colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
        colo_model_tensor_clone)
from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use
import torch
@@ -84,6 +85,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4, 5])
+@rerun_if_address_is_in_use()
def test_zero_tensor_utils(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
...
@@ -9,7 +9,7 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
-from colossalai.testing import rerun_on_exception
+from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model.utils import col_model_deepcopy
@@ -96,7 +96,7 @@ def run_dist(rank, world_size, port, parallel_config):
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_mp_engine(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
    mp.spawn(run_func, nprocs=world_size)
@@ -104,7 +104,7 @@ def test_mp_engine(world_size):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
def test_zero_engine(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
    mp.spawn(run_func, nprocs=world_size)
...