Unverified Commit 5a1a095b authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[test] refactored with the new rerun decorator (#763)

* [test] refactored with the new rerun decorator

* polish test case
parent deaf99f4
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.zero.sharded_param import ShardedTensor
import colossalai
......@@ -35,7 +35,7 @@ def run_tensor_move(rank):
assert (tgt_t.device.type == 'cpu')
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_tensor_move():
mp.spawn(run_tensor_move, nprocs=1)
......
......@@ -3,6 +3,7 @@ from functools import partial
from pathlib import Path
import colossalai
from colossalai.testing.utils import rerun_if_address_is_in_use
import pytest
import torch
import torch.multiprocessing as mp
......@@ -10,7 +11,7 @@ import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_dataloader
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
......@@ -87,7 +88,7 @@ def run_no_pipeline(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_engine():
world_size = 4
func = partial(run_no_pipeline, world_size=world_size, port=free_port())
......
......@@ -16,7 +16,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from functools import partial
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
def checkpoint_wrapper(module, enable=True):
......@@ -102,7 +102,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_zero_clip_grad():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
......
......@@ -6,7 +6,7 @@ import pytest
import torch
import torch.multiprocessing as mp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy
......@@ -62,7 +62,7 @@ def _run_dist(rank, world_size, port):
# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_found_inf(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -8,7 +8,7 @@ import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
......@@ -64,7 +64,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_zero_init_context(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -10,7 +10,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from functools import partial
......@@ -64,7 +64,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_mem_collector(world_size=2):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -7,7 +7,7 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
......@@ -59,7 +59,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_shard_model_v2(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -5,12 +5,11 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.testing import rerun_on_exception
from tests.test_zero.common import CONFIG, allclose
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
......@@ -37,7 +36,7 @@ def _run_shard_tensor(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_shard_tensor(world_size):
run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......@@ -85,7 +84,7 @@ def _run_shard_param_v2(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_shard_param_v2(world_size):
run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -8,7 +8,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
......@@ -105,7 +105,7 @@ def _run_dist(rank, world_size, port):
# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_sharded_optim_v2(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -10,7 +10,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
......@@ -71,7 +71,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_sharded_optim_with_sync_bn():
"""
This test is to make sure that buffers are synchronized between ranks
......
......@@ -8,7 +8,7 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
......@@ -49,7 +49,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_zero_state_dict(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -10,7 +10,7 @@ from colossalai.gemini import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from torch.nn.parameter import Parameter
from typing import List
from functools import partial
......@@ -120,8 +120,8 @@ def run_dist(rank, world_size, port):
run_stm()
@pytest.mark.gpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_stateful_tensor_manager(world_size=1):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -6,6 +6,7 @@ from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
colo_model_tensor_clone)
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
import torch
......@@ -84,6 +85,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4, 5])
@rerun_if_address_is_in_use()
def test_zero_tensor_utils(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -9,7 +9,7 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model.utils import col_model_deepcopy
......@@ -96,7 +96,7 @@ def run_dist(rank, world_size, port, parallel_config):
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_mp_engine(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
mp.spawn(run_func, nprocs=world_size)
......@@ -104,7 +104,7 @@ def test_mp_engine(world_size):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_zero_engine(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
mp.spawn(run_func, nprocs=world_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment