Unverified commit 5a1a095b, authored by Frank Lee and committed by GitHub
Browse files

[test] refactored with the new rerun decorator (#763)

* [test] refactored with the new rerun decorator

* polish test case
parent deaf99f4
...@@ -3,7 +3,7 @@ import colossalai ...@@ -3,7 +3,7 @@ import colossalai
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import assert_close_loose, rerun_on_exception from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
...@@ -84,7 +84,7 @@ def run_dist(rank, world_size, port): ...@@ -84,7 +84,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_naive_amp(): def test_naive_amp():
world_size = 1 world_size = 1
run_func = partial(run_dist, world_size=world_size, port=free_port()) run_func = partial(run_dist, world_size=world_size, port=free_port())
......
...@@ -9,7 +9,7 @@ from colossalai.context import ParallelMode ...@@ -9,7 +9,7 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
...@@ -64,7 +64,7 @@ def check_layer(rank, world_size, port): ...@@ -64,7 +64,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_comm(): def test_comm():
world_size = 4 world_size = 4
run_func = partial(check_layer, world_size=world_size, port=free_port()) run_func = partial(check_layer, world_size=world_size, port=free_port())
......
...@@ -13,7 +13,7 @@ from colossalai.core import global_context as gpc ...@@ -13,7 +13,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.context import reset_seeds from colossalai.context import reset_seeds
from colossalai.global_variables import tensor_parallel_env as tp_env from colossalai.global_variables import tensor_parallel_env as tp_env
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py')) CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
...@@ -141,7 +141,7 @@ def run_dist(rank, world_size, backend, port_list, host): ...@@ -141,7 +141,7 @@ def run_dist(rank, world_size, backend, port_list, host):
@pytest.mark.cpu @pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_context(): def test_context():
""" """
As no computation or communication is done, we can run this test on CPU. As no computation or communication is done, we can run this test on CPU.
......
...@@ -17,7 +17,7 @@ from torchvision import transforms ...@@ -17,7 +17,7 @@ from torchvision import transforms
from colossalai.context import ParallelMode, Config from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader, free_port from colossalai.utils import get_dataloader, free_port
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
CONFIG = Config( CONFIG = Config(
dict( dict(
...@@ -67,7 +67,7 @@ def run_data_sampler(rank, world_size, port): ...@@ -67,7 +67,7 @@ def run_data_sampler(rank, world_size, port):
@pytest.mark.cpu @pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_data_sampler(): def test_data_sampler():
world_size = 4 world_size = 4
test_func = partial(run_data_sampler, world_size=world_size, port=free_port()) test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
......
...@@ -17,7 +17,7 @@ from colossalai.builder import build_dataset, build_transform ...@@ -17,7 +17,7 @@ from colossalai.builder import build_dataset, build_transform
from colossalai.context import ParallelMode, Config from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
CONFIG = Config( CONFIG = Config(
dict( dict(
...@@ -79,7 +79,7 @@ def run_data_sampler(rank, world_size, port): ...@@ -79,7 +79,7 @@ def run_data_sampler(rank, world_size, port):
@pytest.mark.cpu @pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_data_sampler(): def test_data_sampler():
world_size = 4 world_size = 4
test_func = partial(run_data_sampler, world_size=world_size, port=free_port()) test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
......
...@@ -15,7 +15,7 @@ from colossalai.nn.loss import CrossEntropyLoss ...@@ -15,7 +15,7 @@ from colossalai.nn.loss import CrossEntropyLoss
from colossalai.trainer import Trainer, hooks from colossalai.trainer import Trainer, hooks
from colossalai.utils import free_port, get_dataloader from colossalai.utils import free_port, get_dataloader
from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
from model_zoo.vit import vit_tiny_patch4_32 from model_zoo.vit import vit_tiny_patch4_32
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
...@@ -23,7 +23,8 @@ from torchvision.datasets import CIFAR10 ...@@ -23,7 +23,8 @@ from torchvision.datasets import CIFAR10
BATCH_SIZE = 4 BATCH_SIZE = 4
NUM_EPOCHS = 60 NUM_EPOCHS = 60
WARMUP_EPOCHS = 5 WARMUP_EPOCHS = 5
CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')), CONFIG = dict(NUM_MICRO_BATCHES=2,
parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
fp16=dict(mode=AMP_TYPE.NAIVE), fp16=dict(mode=AMP_TYPE.NAIVE),
gradient_accumulation=2) gradient_accumulation=2)
...@@ -79,7 +80,7 @@ def run_trainer(rank, world_size, port): ...@@ -79,7 +80,7 @@ def run_trainer(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_hybrid_parallel(): def test_hybrid_parallel():
world_size = 8 world_size = 8
run_func = partial(run_trainer, world_size=world_size, port=free_port()) run_func = partial(run_trainer, world_size=world_size, port=free_port())
......
...@@ -7,7 +7,7 @@ from colossalai.amp import AMP_TYPE ...@@ -7,7 +7,7 @@ from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.utils import free_port from colossalai.utils import free_port
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize, rerun_on_exception from colossalai.testing import parameterize, rerun_if_address_is_in_use
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
fp16=dict(mode=None), fp16=dict(mode=None),
...@@ -56,7 +56,7 @@ def run_engine(rank, world_size, port): ...@@ -56,7 +56,7 @@ def run_engine(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_engine(): def test_engine():
world_size = 2 world_size = 2
run_func = partial(run_engine, world_size=world_size, port=free_port()) run_func = partial(run_engine, world_size=world_size, port=free_port())
......
...@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc ...@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
from checks_1d.check_layer_1d import * from checks_1d.check_layer_1d import *
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)
...@@ -35,7 +35,7 @@ def check_layer(rank, world_size, port): ...@@ -35,7 +35,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_1d(): def test_1d():
world_size = 4 world_size = 4
run_func = partial(check_layer, world_size=world_size, port=free_port()) run_func = partial(check_layer, world_size=world_size, port=free_port())
......
...@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc ...@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight, from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed, check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight, check_vocab_parallel_classifier_given_embed_weight,
...@@ -55,7 +55,7 @@ def check_layer_and_operation(rank, world_size, port): ...@@ -55,7 +55,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_2d(): def test_2d():
world_size = 4 world_size = 4
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
......
...@@ -7,7 +7,7 @@ from colossalai.core import global_context as gpc ...@@ -7,7 +7,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
from checks_2p5d.check_layer_2p5d import * from checks_2p5d.check_layer_2p5d import *
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
...@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port): ...@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_2p5d(): def test_2p5d():
world_size = 4 world_size = 4
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
......
...@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc ...@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight, from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed, check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight, check_vocab_parallel_classifier_given_embed_weight,
...@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port): ...@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_3d(): def test_3d():
world_size = 8 world_size = 8
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
......
...@@ -7,7 +7,7 @@ import pytest ...@@ -7,7 +7,7 @@ import pytest
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
from functools import partial from functools import partial
CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence'))) CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
...@@ -132,7 +132,7 @@ def run_test(rank, world_size): ...@@ -132,7 +132,7 @@ def run_test(rank, world_size):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_sequence(): def test_sequence():
world_size = 4 world_size = 4
run_func = partial(run_test, world_size=world_size) run_func = partial(run_test, world_size=world_size)
......
...@@ -10,8 +10,7 @@ from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, ...@@ -10,8 +10,7 @@ from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer,
from colossalai.context.moe_context import MOE_CONTEXT from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param from colossalai.utils.moe import sync_moe_model_param
from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.testing import assert_equal_in_group from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
from colossalai.testing import rerun_on_exception
BATCH_SIZE = 4 BATCH_SIZE = 4
DIM = 16 DIM = 16
...@@ -63,7 +62,7 @@ def run_test(rank, world_size, port): ...@@ -63,7 +62,7 @@ def run_test(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_grad_handler(): def test_grad_handler():
world_size = 4 world_size = 4
run_func = partial(run_test, world_size=world_size, port=free_port()) run_func = partial(run_test, world_size=world_size, port=free_port())
......
...@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc ...@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
from colossalai.context.moe_context import MOE_CONTEXT from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
BATCH_SIZE = 16 BATCH_SIZE = 16
NUM_EXPERTS = 4 NUM_EXPERTS = 4
...@@ -87,7 +87,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f ...@@ -87,7 +87,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
@pytest.mark.parametrize("hidden_size", [32, 144]) @pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16]) @pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
@pytest.mark.parametrize("router", [Top1Router, Top2Router]) @pytest.mark.parametrize("router", [Top1Router, Top2Router])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_moe_kernel(rs, hidden_size, data_type, router): def test_moe_kernel(rs, hidden_size, data_type, router):
world_size = 4 world_size = 4
run_func = partial(run_routing, run_func = partial(run_routing,
......
...@@ -8,7 +8,7 @@ from colossalai.utils import free_port, get_current_device ...@@ -8,7 +8,7 @@ from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Experts from colossalai.nn.layer.moe import Experts
from colossalai.context.moe_context import MOE_CONTEXT from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param from colossalai.utils.moe import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_on_exception from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
D_MODEL = 4 D_MODEL = 4
D_FF = 8 D_FF = 8
...@@ -60,7 +60,7 @@ def run_test(rank, port): ...@@ -60,7 +60,7 @@ def run_test(rank, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_moe_initialization(): def test_moe_initialization():
world_size = 4 world_size = 4
run_func = partial(run_test, port=free_port()) run_func = partial(run_test, port=free_port())
......
...@@ -14,7 +14,7 @@ from colossalai.nn.layer import MoeModule ...@@ -14,7 +14,7 @@ from colossalai.nn.layer import MoeModule
from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from tests.test_zero.common import CONFIG from tests.test_zero.common import CONFIG
...@@ -91,7 +91,7 @@ def _run_dist(rank, world_size, port): ...@@ -91,7 +91,7 @@ def _run_dist(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4]) @pytest.mark.parametrize("world_size", [2, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_moe_zero_init(world_size): def test_moe_zero_init(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port()) run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
......
...@@ -4,7 +4,7 @@ import colossalai ...@@ -4,7 +4,7 @@ import colossalai
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_on_exception from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
...@@ -65,7 +65,7 @@ def run_dist(rank, world_size, port): ...@@ -65,7 +65,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("world_size", [2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_moe_zero_model(world_size): def test_moe_zero_model(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.amp import convert_to_apex_amp from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import CPUAdam from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize, rerun_on_exception from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
...@@ -120,7 +120,7 @@ def _run_dist(rank, world_size, port): ...@@ -120,7 +120,7 @@ def _run_dist(rank, world_size, port):
# use_cpuadam = True can be used with cpu_offload = False # use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("world_size", [2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size): def test_moe_zero_optim(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port()) run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
......
...@@ -9,7 +9,7 @@ from colossalai.logging import get_dist_logger ...@@ -9,7 +9,7 @@ from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer, free_port from colossalai.utils import MultiTimer, free_port
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize, rerun_on_exception from colossalai.testing import parameterize, rerun_if_address_is_in_use
BATCH_SIZE = 4 BATCH_SIZE = 4
IMG_SIZE = 32 IMG_SIZE = 32
...@@ -51,7 +51,7 @@ def run_dist(rank, world_size, port): ...@@ -51,7 +51,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_trainer_no_pipeline(): def test_trainer_no_pipeline():
world_size = 4 world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port()) run_func = partial(run_dist, world_size=world_size, port=free_port())
......
...@@ -17,13 +17,16 @@ from torch.optim import Adam ...@@ -17,13 +17,16 @@ from torch.optim import Adam
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from torchvision.models import resnet18 from torchvision.models import resnet18
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_if_address_is_in_use
BATCH_SIZE = 4 BATCH_SIZE = 4
IMG_SIZE = 32 IMG_SIZE = 32
NUM_EPOCHS = 200 NUM_EPOCHS = 200
CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),) CONFIG = dict(
NUM_MICRO_BATCHES=2,
parallel=dict(pipeline=2),
)
def run_trainer_with_pipeline(rank, world_size, port): def run_trainer_with_pipeline(rank, world_size, port):
...@@ -85,7 +88,7 @@ def run_trainer_with_pipeline(rank, world_size, port): ...@@ -85,7 +88,7 @@ def run_trainer_with_pipeline(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_trainer_with_pipeline(): def test_trainer_with_pipeline():
world_size = 4 world_size = 4
run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port()) run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment