Unverified commit 80eba05b, authored by Frank Lee and committed by GitHub

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
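
The recurring change in the hunks below is the launcher: instead of building a functools.partial with an explicit free_port() and handing it to torch.multiprocessing.spawn, every distributed test now calls spawn from colossalai.testing, and each worker function takes (rank, world_size, port). A minimal sketch of the target pattern follows; the details of spawn's behaviour (that it starts nprocs workers and fills in rank, world size and a free port by itself) are inferred from the call sites in this diff, not from the helper's implementation.

# Hedged sketch of the new launch pattern; `spawn` picking a free port on its own
# is an assumption based on how free_port() disappears from the call sites below.
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # every worker joins the same process group; the arguments are injected by spawn
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # ... distributed assertions go here ...


@rerun_if_address_is_in_use()
def test_example():
    # old style: mp.spawn(partial(run_dist, world_size=4, port=free_port()), nprocs=4)
    spawn(run_dist, 4)
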
 import pytest
 import torch
+from colossalai.testing import clear_cache_before_run, parameterize
 try:
     from colossalai._analyzer.fx import symbolic_trace
 except:
...
@@ -62,9 +64,10 @@ class AModel(torch.nn.Module):
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
-@pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("bias_addition_split", [True, False])
-@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
+@clear_cache_before_run()
+@parameterize("bias", [True, False])
+@parameterize("bias_addition_split", [True, False])
+@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
 def test_mod_dir(bias, bias_addition_split, shape):
     model = AModel(bias=bias)
     x = torch.rand(shape)
...
@@ -75,4 +78,4 @@ def test_mod_dir(bias, bias_addition_split, shape):
 if __name__ == '__main__':
-    test_mod_dir(True, True, (3, 3, 3))
+    test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3))
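
A note on the decorator swap shown above: @pytest.mark.parametrize generates one pytest case per value, while @parameterize from colossalai.testing appears to iterate over the values inside a single test function, and the __main__ block now calls the test directly with keyword arguments; @clear_cache_before_run() presumably clears cached state (such as GPU memory) before the body runs. A small sketch under those assumptions, with hypothetical names:

from colossalai.testing import clear_cache_before_run, parameterize


@clear_cache_before_run()                            # assumed: clears cached state before the test body
@parameterize("bias", [True, False])                 # assumed: loops over the values inside one test item
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
def check_dir_example(bias, shape):
    print(f"bias={bias}, shape={shape}")


if __name__ == '__main__':
    check_dir_example()                              # the decorators supply every argument combination
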
+import pytest
 import torch
 import torch.nn as nn
 from torch.utils.checkpoint import checkpoint
-import pytest
+from colossalai.testing import clear_cache_before_run
 try:
     from colossalai._analyzer.fx import symbolic_trace
...
@@ -42,6 +44,7 @@ class MyModule(nn.Module):
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@clear_cache_before_run()
 def test_nested_ckpt():
     model = MyModule()
     x = torch.rand(10, 10)
...
...
@@ -3,7 +3,7 @@ import torch
 import torchvision.models as tm
 from packaging import version
-from colossalai.testing.utils import parameterize
+from colossalai.testing.utils import clear_cache_before_run, parameterize
 from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
 try:
...
@@ -32,6 +32,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@clear_cache_before_run()
 @parameterize('m', tm_models)
 def test_torchvision_shape_prop(m):
     with MetaTensorMode():
...
@@ -46,6 +47,7 @@ def test_torchvision_shape_prop(m):
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@clear_cache_before_run()
 @parameterize('m', tmm_models)
 def test_timm_shape_prop(m):
     with MetaTensorMode():
...
...
@@ -3,7 +3,7 @@ import torch
 import torchvision.models as tm
 from packaging import version
-from colossalai.testing.utils import parameterize
+from colossalai.testing.utils import clear_cache_before_run, parameterize
 from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
 try:
...
@@ -19,6 +19,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@clear_cache_before_run()
 @parameterize('m', tm_models)
 def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
     with MetaTensorMode():
...
@@ -33,6 +34,7 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@clear_cache_before_run()
 @parameterize('m', tmm_models)
 def test_timm_profile(m, verbose=False, bias_addition_split=False):
     with MetaTensorMode():
...
 from typing import Any, Callable, Union
 import pytest
 import torch
 import torch.nn as nn
+from colossalai.testing import clear_cache_before_run
 try:
     from colossalai._analyzer._subclasses import MetaTensor
 except:
...
@@ -72,6 +74,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@clear_cache_before_run()
 def test_meta_aten():
     for (aten_op, requires_backward), v in registered_meta.items():
         for f, x in v:
...
...
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 import torchvision.models as tm
 from packaging import version
+from colossalai.testing import clear_cache_before_run, parameterize
 from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
 try:
...
@@ -39,7 +40,8 @@ odd_cases = [
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
-@pytest.mark.parametrize('func, args, kwargs', odd_cases)
+@clear_cache_before_run()
+@parameterize('func, args, kwargs', odd_cases)
 def test_flop_count_function(func, args, kwargs):
     rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
     assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'
...
...
@@ -3,6 +3,8 @@ import torch
 import torchvision.models as tm
 from packaging import version
+from colossalai.testing import clear_cache_before_run, parameterize
 try:
     from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
 except:
...
@@ -30,7 +32,8 @@ def run_and_compare(model):
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
-@pytest.mark.parametrize('m', tm_models + tmm_models)
+@clear_cache_before_run()
+@parameterize('m', tm_models + tmm_models)
 def test_meta_mode_shape(m):
     run_and_compare(m())
...
...
@@ -3,7 +3,6 @@ import copy
 import pytest
 import torch
 import torch.fx
-import torch.multiprocessing as mp
 import torchvision.models as tm
 import colossalai
...
@@ -13,7 +12,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
 # from colossalai.fx.passes.algorithms import solver_rotor
 # from colossalai.fx.passes.algorithms.operation import Sequence
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor
...
@@ -26,8 +25,8 @@ except:
     withcodegen = False
-def _run_C_solver_consistency_test(rank=0):
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+def _run_C_solver_consistency_test(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
         model = M()
...
@@ -70,8 +69,9 @@ def _run_C_solver_consistency_test(rank=0):
 @pytest.mark.skip("TODO(lyl): refactor all tests.")
 @pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
+@rerun_if_address_is_in_use()
 def test_C_solver_consistency():
-    mp.spawn(_run_C_solver_consistency_test, nprocs=1)
+    spawn(_run_C_solver_consistency_test, 1)
 if __name__ == '__main__':
...
...
@@ -4,7 +4,6 @@ from typing import Callable
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torchvision.models as tm
 from torch.fx import GraphModule
...
@@ -15,7 +14,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.graph_module import ColoGraphModule
 # from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor
...
@@ -68,8 +67,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
     assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
-def _run_ckpt_solver(rank):
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+def _run_ckpt_solver(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     MODEL_LIST = [tm.densenet121]
     torch.backends.cudnn.deterministic = True
...
@@ -98,12 +97,13 @@ def _run_ckpt_solver(rank):
 @pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+@rerun_if_address_is_in_use()
 def test_ckpt_solver():
-    mp.spawn(_run_ckpt_solver, nprocs=1)
+    spawn(_run_ckpt_solver, 1)
-def _run_ckpt_solver_torch11(rank):
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+def _run_ckpt_solver_torch11(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     MODEL_LIST = [tm.densenet121]
     torch.backends.cudnn.deterministic = True
...
@@ -131,8 +131,9 @@ def _run_ckpt_solver_torch11(rank):
 @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
 @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
+@rerun_if_address_is_in_use()
 def test_ckpt_solver_torch11():
-    mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
+    spawn(_run_ckpt_solver_torch11, 1)
 if __name__ == '__main__':
...
...
@@ -8,6 +8,7 @@ from colossalai.fx.graph_module import ColoGraphModule
 # from colossalai.fx.passes.algorithms import linearize, solver_rotor
 # from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.testing import clear_cache_before_run
 if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor
...
@@ -24,6 +25,7 @@ except:
 @pytest.mark.skip(reason='TODO: modify the logger')
 @pytest.mark.skip("TODO(lyl): refactor all tests.")
 @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
+@clear_cache_before_run()
 def test_linearize():
     MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
     tracer = ColoTracer()
...
@@ -84,6 +86,7 @@ def test_linearize():
 @pytest.mark.skip("TODO(lyl): refactor all tests.")
 @pytest.mark.skip(reason="torch11 meta tensor not implemented")
 @pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
+@clear_cache_before_run()
 def test_linearize_torch11():
     MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
     tracer = ColoTracer()
...
 import time
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 from torch.utils._pytree import tree_map
 import colossalai
...
@@ -12,8 +10,8 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
 from colossalai.auto_parallel.offload.solver import NOT_NVML
 from colossalai.fx.profiler import parameter_size
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize
-from colossalai.utils import free_port, get_current_device
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
 from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
 from tests.test_auto_parallel.test_offload.model_utils import *
 from tests.test_tensor.common_utils import set_seed
...
@@ -140,9 +138,9 @@ def run_dist(rank, world_size, port):
 @pytest.mark.skip("this test failed")
 @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+@rerun_if_address_is_in_use()
 def test_perf():
-    run_func = partial(run_dist, world_size=1, port=free_port())
-    mp.spawn(run_func, nprocs=1)
+    spawn(run_dist, 1)
 if __name__ == '__main__':
...
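
@rerun_if_address_is_in_use(), added in front of the distributed tests, is presumably a retry decorator: if launching the process group fails because the rendezvous address is still bound from a previous run, the test is re-run instead of failing outright. A hedged sketch of how it stacks with the existing marks:

import pytest
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist_example(rank, world_size, port):
    # placeholder worker; a real test would run distributed assertions here
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')


@pytest.mark.dist
@rerun_if_address_is_in_use()      # assumed: re-runs the test when the port is still busy
def test_distributed_example():
    spawn(run_dist_example, 1)
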
...
@@ -3,20 +3,20 @@ import torch.fx
 from torch.fx import GraphModule
 from torch.utils._pytree import tree_map
+from colossalai.auto_parallel.offload.region_manager import RegionManager
+from colossalai.auto_parallel.offload.solver import NOT_NVML, SolverFactory
 from colossalai.fx import ColoTracer, is_compatible_with_meta
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.auto_parallel.offload.region_manager import RegionManager
-from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML
-from colossalai.testing import parameterize
+from colossalai.testing import clear_cache_before_run, parameterize
 from tests.test_auto_parallel.test_offload.model_utils import *
 @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+@clear_cache_before_run()
 @parameterize('model_name', ['gpt2_', 'bert_'])
 @parameterize('memory_budget', [4000])
 @parameterize('solver_name', ['syn', 'asyn'])
-def solver_test(model_name: str,
-                memory_budget: float,
-                solver_name: str):
+def solver_test(model_name: str, memory_budget: float, solver_name: str):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, data_gen = get_components_func()
...
@@ -52,11 +52,16 @@ def solver_test(model_name: str,
     for region in region_list:
         need_offload = region.need_offload
         to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None
-        print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
+        print(
+            f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
+        )
     for region in region_list.__reversed__():
         need_offload = region.need_offload
         to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None
-        print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
+        print(
+            f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
+        )
 if __name__ == '__main__':
     solver_test()
\ No newline at end of file
...
@@ -6,6 +6,7 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.testing import clear_cache_before_run
 class TestModule(torch.nn.Module):
...
@@ -26,6 +27,7 @@ def insert_narrow(gm, x_node):
     return gm
+@clear_cache_before_run()
 def test_node_args_converting_pass():
     model = TestModule()
     physical_mesh_id = torch.arange(0, 4)
...
...
@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
 from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.testing import clear_cache_before_run
 class TestModule(torch.nn.Module):
...
@@ -36,6 +37,7 @@ def recover_narrow(gm, narrow_node):
 @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
+@clear_cache_before_run()
 def test_size_value_converting_pass():
     model = TestModule()
     physical_mesh_id = torch.arange(0, 4)
...
...
@@ -2,7 +2,6 @@ from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 try:
     from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
...
@@ -13,9 +12,7 @@ except:
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
 class LinearModel(torch.nn.Module):
...
@@ -86,11 +83,8 @@ def check_conv_module(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_bias_addition_module():
-    world_size = 4
-    run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port())
-    mp.spawn(run_func_linear, nprocs=world_size)
-    run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
-    mp.spawn(run_func_conv, nprocs=world_size)
+    spawn(check_linear_module, 4)
+    spawn(check_conv_module, 4)
 if __name__ == '__main__':
...
-from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.utils.checkpoint import checkpoint
 from transformers.pytorch_utils import Conv1D
...
@@ -17,9 +15,7 @@ except:
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
 HIDDEN_SIZE = 16
...
@@ -65,9 +61,7 @@ def check_act_ckpt(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_mlp_layer():
-    world_size = 4
-    run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_act_ckpt, 4)
 if __name__ == '__main__':
...
 import copy
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 try:
...
@@ -15,9 +13,7 @@ except:
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
 class MLP(torch.nn.Module):
...
@@ -102,9 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_compatibility_with_ddp():
-    world_size = 4
-    run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_compatibility_with_ddp, 4)
 if __name__ == '__main__':
...
 import copy
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
 try:
     from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
...
@@ -17,10 +14,9 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.process_group import ProcessGroup
-from colossalai.testing import assert_close, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port, get_current_device
-from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
+from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
+from colossalai.utils import get_current_device
+from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
 class MLP(torch.nn.Module):
...
@@ -110,9 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_auto_parallel_with_gemini():
-    world_size = 4
-    run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_auto_parallel_with_gemini, 4)
 if __name__ == '__main__':
...
...
@@ -10,8 +10,7 @@ from colossalai._analyzer.fx.passes import shape_prop_pass
 # from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai._analyzer.fx.tracer.tracer import ColoTracer
 from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
-from colossalai.testing import parameterize
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, parameterize, run_on_environment_flag
 NUM_REPEAT_BLOCKS = 4
 BATCH_SIZE = 1
...
@@ -81,6 +80,7 @@ class NonRepeatModel(nn.Module):
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
 @parameterize('model_cls', [RepeatModel, NonRepeatModel])
 def test_repeat_blocks(model_cls):
...
 import copy
 import random
-from functools import partial
 from typing import Dict
 import numpy as np
 import pytest
 import torch
-import torch.multiprocessing as mp
 import transformers
 from torch.fx import GraphModule
...
@@ -30,9 +28,8 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.shape_consistency import to_global
-from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
 from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 BATCH_SIZE = 1
...
@@ -190,9 +187,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
 @parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
 @rerun_if_address_is_in_use()
 def test_mlp_layer(model_cls):
-    world_size = 4
-    run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_attention_layer, 4, model_cls=model_cls)
 if __name__ == '__main__':
...
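
The GPT test above also shows that extra keyword arguments pass straight through spawn to the worker: model_cls is given once at the call site and every rank receives it alongside the injected (rank, world_size, port). A sketch of that calling convention, read off this call site rather than the helper's code:

from colossalai.testing import spawn


class DummyBlock:
    # stand-in for GPT2MLP / GPT2Block etc.; purely illustrative
    pass


def check_attention_layer(rank, model_cls, world_size, port):
    # rank, world_size and port are injected by spawn; model_cls is the forwarded kwarg
    print(f"rank {rank}/{world_size} on port {port} builds {model_cls.__name__}")


if __name__ == '__main__':
    # assumed signature: spawn(func, nprocs, **kwargs); kwargs go to every worker
    spawn(check_attention_layer, 4, model_cls=DummyBlock)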