"examples/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "c7925c5d089d9927d93525a92180a3b66c649091"
Unverified commit 3601b2ba, authored by Frank Lee and committed by GitHub

[test] fixed rerun_on_exception and adapted test cases (#487)

parent 4d322b79
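This commit imports `rerun_on_exception` from `colossalai.testing` and applies it to the distributed test entry points so that a test is retried when one of its spawned workers dies with an "Address already in use" error (a transient port collision between test runs). The real decorator lives in `colossalai.testing`; the snippet below is only a minimal sketch of the retry behaviour the call sites rely on, with a hypothetical `max_try` default that is not taken from this commit.

```python
import functools
import re


def rerun_on_exception(exception_type=Exception, pattern=None, max_try=5):
    """Sketch only: re-run the wrapped test when it raises `exception_type`
    and the exception message matches `pattern` (a regex).
    `max_try` is a hypothetical default, not from colossalai.testing."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_try):
                try:
                    return func(*args, **kwargs)
                except exception_type as e:
                    # re-raise immediately if the message does not match the
                    # pattern, or if this was the last allowed attempt
                    if pattern is not None and not re.search(pattern, str(e)):
                        raise
                    if attempt == max_try - 1:
                        raise
        return wrapper
    return decorator
```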
@@ -17,6 +17,7 @@ from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
 from torchvision.models import resnet18
+from colossalai.testing import rerun_on_exception
 BATCH_SIZE = 4
 IMG_SIZE = 32
@@ -85,6 +86,7 @@ def run_trainer_with_pipeline(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_trainer_with_pipeline():
     world_size = 4
     run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())
...
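The test body after the decorator is elided here, but the same pattern appears in full in the hunks below: grab a free TCP port, bind it into the per-rank worker with `functools.partial`, and launch one process per rank with `torch.multiprocessing.spawn`. A consolidated sketch of that pattern (the worker body is a placeholder, and the final `mp.spawn` line is inferred from the other test files in this commit):

```python
from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.testing import rerun_on_exception
from colossalai.utils import free_port


def run_trainer_with_pipeline(rank, world_size, port):
    # per-rank worker; the real body (launch + pipeline trainer loop)
    # is elided in this diff
    ...


@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_trainer_with_pipeline():
    world_size = 4
    run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())
    # if any worker raises ProcessRaisedException with "Address already in use",
    # the decorator re-runs the whole test, which then picks a fresh port
    mp.spawn(run_func, nprocs=world_size)
```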
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
 from colossalai.zero.sharded_param import ShardedTensor
 import colossalai
@@ -47,6 +47,7 @@ def run_tensor_move(rank):
     GLOBAL_MODEL_DATA_TRACER.close()
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_tensor_move():
     mp.spawn(run_tensor_move, nprocs=1)
...
@@ -10,6 +10,7 @@ import torch.nn as nn
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import free_port, get_dataloader
+from colossalai.testing import rerun_on_exception
 from torch.optim import Adam
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -86,6 +87,7 @@ def run_no_pipeline(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_engine():
     world_size = 4
     func = partial(run_no_pipeline, world_size=world_size, port=free_port())
...
@@ -14,9 +14,9 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
-from colossalai.testing import parameterize
 from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
 from functools import partial
+from colossalai.testing import parameterize, rerun_on_exception
 def checkpoint_wrapper(module, enable=True):
@@ -102,6 +102,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_clip_grad():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())
...
@@ -14,6 +14,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
@@ -57,6 +58,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 4])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_init_context(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
...
@@ -14,6 +14,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -63,6 +64,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_shard_model_v2(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
...
@@ -10,6 +10,7 @@ from colossalai.utils import free_port
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_param import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
+from colossalai.testing import rerun_on_exception
 from tests.test_zero_data_parallel.common import CONFIG, allclose
@@ -35,6 +36,7 @@ def _run_shard_tensor(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_shard_tensor(world_size):
     run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
@@ -55,6 +57,7 @@ def _run_shard_param_v2(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_shard_param_v2(world_size):
     run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
...
@@ -15,6 +15,7 @@ from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -106,6 +107,7 @@ def _run_dist(rank, world_size, port):
 # use_cpuadam = True can be used with cpu_offload = False
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_sharded_optim_v2(world_size):
     run_func = partial(_run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
...
@@ -14,6 +14,7 @@ from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import TensorShardStrategy
 from torchvision.models import resnet50
+from colossalai.testing import rerun_on_exception
 def run_dist(rank, world_size, port):
@@ -71,6 +72,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_sharded_optim_with_sync_bn():
     """
     This test is to make sure that buffers are synchronized between ranks
...
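The docstring above states the test's purpose: making sure buffers (such as SyncBatchNorm running statistics, given the test's name) end up identical on every rank. Its actual assertions are elided in this diff; purely as an illustration, a buffer-synchronization check inside a distributed worker can look like the following hypothetical helper, assuming an already-initialized process group.

```python
import torch
import torch.distributed as dist


def assert_buffer_synced(buf: torch.Tensor, world_size: int):
    # Hypothetical helper, not the test's real code: gather this rank's copy
    # of the buffer from every rank and check that all copies are identical.
    gathered = [torch.empty_like(buf) for _ in range(world_size)]
    dist.all_gather(gathered, buf)
    for other in gathered[1:]:
        assert torch.equal(gathered[0], other), "buffer diverged across ranks"
```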
@@ -14,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
@@ -51,6 +52,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_state_dict(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
...
@@ -13,6 +13,7 @@ from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.testing import rerun_on_exception
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -96,6 +97,7 @@ def run_dist(rank, world_size, port, parallel_config):
 @pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2, 4])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_mp_engine(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)
@@ -103,6 +105,7 @@ def test_mp_engine(world_size):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_zero_engine(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG)
     mp.spawn(run_func, nprocs=world_size)
...