Unverified Commit 5a1a095b authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[test] refactored with the new rerun decorator (#763)

* [test] refactored with the new rerun decorator

* polish test case
parent deaf99f4
......@@ -3,7 +3,7 @@ import colossalai
import torch.multiprocessing as mp
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import assert_close_loose, rerun_on_exception
from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
......@@ -84,7 +84,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_naive_amp():
world_size = 1
run_func = partial(run_dist, world_size=world_size, port=free_port())
......
......@@ -9,7 +9,7 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port, get_current_device
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
......@@ -64,7 +64,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_comm():
world_size = 4
run_func = partial(check_layer, world_size=world_size, port=free_port())
......
......@@ -13,7 +13,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.context import reset_seeds
from colossalai.global_variables import tensor_parallel_env as tp_env
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py'))
......@@ -141,7 +141,7 @@ def run_dist(rank, world_size, backend, port_list, host):
@pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_context():
"""
As no computation or communication is done, we can run this test on CPU.
......
......@@ -17,7 +17,7 @@ from torchvision import transforms
from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader, free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
CONFIG = Config(
dict(
......@@ -67,7 +67,7 @@ def run_data_sampler(rank, world_size, port):
@pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_data_sampler():
world_size = 4
test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
......
......@@ -17,7 +17,7 @@ from colossalai.builder import build_dataset, build_transform
from colossalai.context import ParallelMode, Config
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
CONFIG = Config(
dict(
......@@ -79,7 +79,7 @@ def run_data_sampler(rank, world_size, port):
@pytest.mark.cpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_data_sampler():
world_size = 4
test_func = partial(run_data_sampler, world_size=world_size, port=free_port())
......
......@@ -15,7 +15,7 @@ from colossalai.nn.loss import CrossEntropyLoss
from colossalai.trainer import Trainer, hooks
from colossalai.utils import free_port, get_dataloader
from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from model_zoo.vit import vit_tiny_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10
......@@ -23,9 +23,10 @@ from torchvision.datasets import CIFAR10
BATCH_SIZE = 4
NUM_EPOCHS = 60
WARMUP_EPOCHS = 5
CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
fp16=dict(mode=AMP_TYPE.NAIVE),
gradient_accumulation=2)
CONFIG = dict(NUM_MICRO_BATCHES=2,
parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
fp16=dict(mode=AMP_TYPE.NAIVE),
gradient_accumulation=2)
def run_trainer(rank, world_size, port):
......@@ -79,7 +80,7 @@ def run_trainer(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_hybrid_parallel():
world_size = 8
run_func = partial(run_trainer, world_size=world_size, port=free_port())
......
......@@ -7,7 +7,7 @@ from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
fp16=dict(mode=None),
......@@ -56,7 +56,7 @@ def run_engine(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_engine():
world_size = 2
run_func = partial(run_engine, world_size=world_size, port=free_port())
......
......@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from checks_1d.check_layer_1d import *
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)
......@@ -35,7 +35,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_1d():
world_size = 4
run_func = partial(check_layer, world_size=world_size, port=free_port())
......
......@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
......@@ -55,7 +55,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_2d():
world_size = 4
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
......
......@@ -7,7 +7,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from checks_2p5d.check_layer_2p5d import *
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
......@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_2p5d():
world_size = 4
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
......
......@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
......@@ -51,7 +51,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_3d():
world_size = 8
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
......
......@@ -7,7 +7,7 @@ import pytest
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from functools import partial
CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
......@@ -132,7 +132,7 @@ def run_test(rank, world_size):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_sequence():
world_size = 4
run_func = partial(run_test, world_size=world_size)
......
......@@ -10,8 +10,7 @@ from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer,
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.testing import assert_equal_in_group
from colossalai.testing import rerun_on_exception
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
BATCH_SIZE = 4
DIM = 16
......@@ -63,7 +62,7 @@ def run_test(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_grad_handler():
world_size = 4
run_func = partial(run_test, world_size=world_size, port=free_port())
......
......@@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
BATCH_SIZE = 16
NUM_EXPERTS = 4
......@@ -87,7 +87,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
@pytest.mark.parametrize("router", [Top1Router, Top2Router])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_moe_kernel(rs, hidden_size, data_type, router):
world_size = 4
run_func = partial(run_routing,
......
......@@ -8,7 +8,7 @@ from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Experts
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_on_exception
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
D_MODEL = 4
D_FF = 8
......@@ -60,7 +60,7 @@ def run_test(rank, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_moe_initialization():
world_size = 4
run_func = partial(run_test, port=free_port())
......
......@@ -14,7 +14,7 @@ from colossalai.nn.layer import MoeModule
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import get_current_device
from tests.test_zero.common import CONFIG
......@@ -91,7 +91,7 @@ def _run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_moe_zero_init(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -4,7 +4,7 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
......@@ -65,7 +65,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -6,7 +6,7 @@ import torch
import torch.multiprocessing as mp
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
......@@ -120,7 +120,7 @@ def _run_dist(rank, world_size, port):
# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
......
......@@ -9,7 +9,7 @@ from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer, free_port
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.testing import parameterize, rerun_if_address_is_in_use
BATCH_SIZE = 4
IMG_SIZE = 32
......@@ -51,7 +51,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_trainer_no_pipeline():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
......
......@@ -17,13 +17,16 @@ from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use
BATCH_SIZE = 4
IMG_SIZE = 32
NUM_EPOCHS = 200
CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=2),)
CONFIG = dict(
NUM_MICRO_BATCHES=2,
parallel=dict(pipeline=2),
)
def run_trainer_with_pipeline(rank, world_size, port):
......@@ -85,7 +88,7 @@ def run_trainer_with_pipeline(rank, world_size, port):
@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_trainer_with_pipeline():
world_size = 4
run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment