Unverified Commit 80eba05b authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
import pytest
import torch
from colossalai.testing import clear_cache_before_run, parameterize
try:
from colossalai._analyzer.fx import symbolic_trace
except:
......@@ -62,9 +64,10 @@ class AModel(torch.nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("bias_addition_split", [True, False])
@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
@clear_cache_before_run()
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
def test_mod_dir(bias, bias_addition_split, shape):
model = AModel(bias=bias)
x = torch.rand(shape)
......@@ -75,4 +78,4 @@ def test_mod_dir(bias, bias_addition_split, shape):
if __name__ == '__main__':
test_mod_dir(True, True, (3, 3, 3))
test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3))
import pytest
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import pytest
from colossalai.testing import clear_cache_before_run
try:
from colossalai._analyzer.fx import symbolic_trace
......@@ -42,6 +44,7 @@ class MyModule(nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@clear_cache_before_run()
def test_nested_ckpt():
model = MyModule()
x = torch.rand(10, 10)
......
......@@ -3,7 +3,7 @@ import torch
import torchvision.models as tm
from packaging import version
from colossalai.testing.utils import parameterize
from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
try:
......@@ -32,6 +32,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
@parameterize('m', tm_models)
def test_torchvision_shape_prop(m):
with MetaTensorMode():
......@@ -46,6 +47,7 @@ def test_torchvision_shape_prop(m):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
@parameterize('m', tmm_models)
def test_timm_shape_prop(m):
with MetaTensorMode():
......
......@@ -3,7 +3,7 @@ import torch
import torchvision.models as tm
from packaging import version
from colossalai.testing.utils import parameterize
from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
try:
......@@ -19,6 +19,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
@parameterize('m', tm_models)
def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
with MetaTensorMode():
......@@ -33,6 +34,7 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
@parameterize('m', tmm_models)
def test_timm_profile(m, verbose=False, bias_addition_split=False):
with MetaTensorMode():
......
from typing import Any, Callable, Union
import pytest
import pytest
import torch
import torch.nn as nn
from colossalai.testing import clear_cache_before_run
try:
from colossalai._analyzer._subclasses import MetaTensor
except:
......@@ -72,6 +74,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@clear_cache_before_run()
def test_meta_aten():
for (aten_op, requires_backward), v in registered_meta.items():
for f, x in v:
......
......@@ -4,6 +4,7 @@ import torch.nn.functional as F
import torchvision.models as tm
from packaging import version
from colossalai.testing import clear_cache_before_run, parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
try:
......@@ -39,7 +40,8 @@ odd_cases = [
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@pytest.mark.parametrize('func, args, kwargs', odd_cases)
@clear_cache_before_run()
@parameterize('func, args, kwargs', odd_cases)
def test_flop_count_function(func, args, kwargs):
rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'
......
......@@ -3,6 +3,8 @@ import torch
import torchvision.models as tm
from packaging import version
from colossalai.testing import clear_cache_before_run, parameterize
try:
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
except:
......@@ -30,7 +32,8 @@ def run_and_compare(model):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models + tmm_models)
@clear_cache_before_run()
@parameterize('m', tm_models + tmm_models)
def test_meta_mode_shape(m):
run_and_compare(m())
......
......@@ -3,7 +3,6 @@ import copy
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
import torchvision.models as tm
import colossalai
......@@ -13,7 +12,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
# from colossalai.fx.passes.algorithms import solver_rotor
# from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
......@@ -26,8 +25,8 @@ except:
withcodegen = False
def _run_C_solver_consistency_test(rank=0):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_C_solver_consistency_test(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
model = M()
......@@ -70,8 +69,9 @@ def _run_C_solver_consistency_test(rank=0):
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
@rerun_if_address_is_in_use()
def test_C_solver_consistency():
mp.spawn(_run_C_solver_consistency_test, nprocs=1)
spawn(_run_C_solver_consistency_test, 1)
if __name__ == '__main__':
......
......@@ -4,7 +4,6 @@ from typing import Callable
import pytest
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from torch.fx import GraphModule
......@@ -15,7 +14,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
# from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
......@@ -68,8 +67,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
def _run_ckpt_solver(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_ckpt_solver(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
......@@ -98,12 +97,13 @@ def _run_ckpt_solver(rank):
@pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
@rerun_if_address_is_in_use()
def test_ckpt_solver():
mp.spawn(_run_ckpt_solver, nprocs=1)
spawn(_run_ckpt_solver, 1)
def _run_ckpt_solver_torch11(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_ckpt_solver_torch11(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
......@@ -131,8 +131,9 @@ def _run_ckpt_solver_torch11(rank):
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
@rerun_if_address_is_in_use()
def test_ckpt_solver_torch11():
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
spawn(_run_ckpt_solver_torch11, 1)
if __name__ == '__main__':
......
......@@ -8,6 +8,7 @@ from colossalai.fx.graph_module import ColoGraphModule
# from colossalai.fx.passes.algorithms import linearize, solver_rotor
# from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.testing import clear_cache_before_run
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
......@@ -24,6 +25,7 @@ except:
@pytest.mark.skip(reason='TODO: modify the logger')
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
@clear_cache_before_run()
def test_linearize():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
tracer = ColoTracer()
......@@ -84,6 +86,7 @@ def test_linearize():
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skip(reason="torch11 meta tensor not implemented")
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
@clear_cache_before_run()
def test_linearize_torch11():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
tracer = ColoTracer()
......
import time
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.utils._pytree import tree_map
import colossalai
......@@ -12,8 +10,8 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize
from colossalai.utils import free_port, get_current_device
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
from tests.test_auto_parallel.test_offload.model_utils import *
from tests.test_tensor.common_utils import set_seed
......@@ -140,9 +138,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.skip("this test failed")
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
@rerun_if_address_is_in_use()
def test_perf():
run_func = partial(run_dist, world_size=1, port=free_port())
mp.spawn(run_func, nprocs=1)
spawn(run_dist, 1)
if __name__ == '__main__':
......
......@@ -3,20 +3,20 @@ import torch.fx
from torch.fx import GraphModule
from torch.utils._pytree import tree_map
from colossalai.auto_parallel.offload.region_manager import RegionManager
from colossalai.auto_parallel.offload.solver import NOT_NVML, SolverFactory
from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.auto_parallel.offload.region_manager import RegionManager
from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML
from colossalai.testing import parameterize
from colossalai.testing import clear_cache_before_run, parameterize
from tests.test_auto_parallel.test_offload.model_utils import *
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
@clear_cache_before_run()
@parameterize('model_name', ['gpt2_', 'bert_'])
@parameterize('memory_budget', [4000])
@parameterize('solver_name', ['syn', 'asyn'])
def solver_test(model_name: str,
memory_budget: float,
solver_name: str):
def solver_test(model_name: str, memory_budget: float, solver_name: str):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, data_gen = get_components_func()
......@@ -52,11 +52,16 @@ def solver_test(model_name: str,
for region in region_list:
need_offload = region.need_offload
to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None
print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
print(
f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
)
for region in region_list.__reversed__():
need_offload = region.need_offload
to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None
print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
print(
f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
)
if __name__ == '__main__':
solver_test()
\ No newline at end of file
solver_test()
......@@ -6,6 +6,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import clear_cache_before_run
class TestModule(torch.nn.Module):
......@@ -26,6 +27,7 @@ def insert_narrow(gm, x_node):
return gm
@clear_cache_before_run()
def test_node_args_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)
......
......@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import clear_cache_before_run
class TestModule(torch.nn.Module):
......@@ -36,6 +37,7 @@ def recover_narrow(gm, narrow_node):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_size_value_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)
......
......@@ -2,7 +2,6 @@ from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
......@@ -13,9 +12,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
class LinearModel(torch.nn.Module):
......@@ -86,11 +83,8 @@ def check_conv_module(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bias_addition_module():
world_size = 4
run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port())
mp.spawn(run_func_linear, nprocs=world_size)
run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
mp.spawn(run_func_conv, nprocs=world_size)
spawn(check_linear_module, 4)
spawn(check_conv_module, 4)
if __name__ == '__main__':
......
from functools import partial
from typing import Optional, Tuple, Union
from typing import Optional, Tuple
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers.pytorch_utils import Conv1D
......@@ -17,9 +15,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
HIDDEN_SIZE = 16
......@@ -65,9 +61,7 @@ def check_act_ckpt(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mlp_layer():
world_size = 4
run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_act_ckpt, 4)
if __name__ == '__main__':
......
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
try:
......@@ -15,9 +13,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
class MLP(torch.nn.Module):
......@@ -102,9 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_compatibility_with_ddp():
world_size = 4
run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_compatibility_with_ddp, 4)
if __name__ == '__main__':
......
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
......@@ -17,10 +14,9 @@ from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.process_group import ProcessGroup
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port, get_current_device
from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.utils import get_current_device
from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
class MLP(torch.nn.Module):
......@@ -110,9 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_auto_parallel_with_gemini():
world_size = 4
run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_auto_parallel_with_gemini, 4)
if __name__ == '__main__':
......
......@@ -10,8 +10,7 @@ from colossalai._analyzer.fx.passes import shape_prop_pass
# from colossalai.fx.tracer.tracer import ColoTracer
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, parameterize, run_on_environment_flag
NUM_REPEAT_BLOCKS = 4
BATCH_SIZE = 1
......@@ -81,6 +80,7 @@ class NonRepeatModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
@parameterize('model_cls', [RepeatModel, NonRepeatModel])
def test_repeat_blocks(model_cls):
......
import copy
import random
from functools import partial
from typing import Dict
import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
import transformers
from torch.fx import GraphModule
......@@ -30,9 +28,8 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import to_global
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
BATCH_SIZE = 1
......@@ -190,9 +187,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
@rerun_if_address_is_in_use()
def test_mlp_layer(model_cls):
world_size = 4
run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_attention_layer, 4, model_cls=model_cls)
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment