"applications/vscode:/vscode.git/clone" did not exist on "46c009dba462800b4f7bd54b4558a37e6326726d"
Unverified commit 80eba05b, authored by Frank Lee, committed by GitHub
Browse files

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
from functools import partial
import torch import torch
from colossalai.fx.tracer.meta_patch import patched_function from colossalai.fx.tracer.meta_patch import patched_function
from functools import partial from colossalai.testing import clear_cache_before_run
def _run(data, patch_fn): def _run(data, patch_fn):
...@@ -22,6 +25,7 @@ def _assert_output_shape(data, patch_fn, expect_exception, output_shape): ...@@ -22,6 +25,7 @@ def _assert_output_shape(data, patch_fn, expect_exception, output_shape):
assert output.shape == output_shape assert output.shape == output_shape
@clear_cache_before_run()
def test_repeat_interleave(): def test_repeat_interleave():
patch_fn = patched_function.torch_repeat_interleave patch_fn = patched_function.torch_repeat_interleave
...@@ -63,6 +67,7 @@ def test_repeat_interleave(): ...@@ -63,6 +67,7 @@ def test_repeat_interleave():
output_shape=materialized_output.shape) output_shape=materialized_output.shape)
@clear_cache_before_run()
def test_torch_max(): def test_torch_max():
data = torch.rand(4, 3) data = torch.rand(4, 3)
out = torch.max(data) out = torch.max(data)
......
...@@ -3,6 +3,7 @@ import torch ...@@ -3,6 +3,7 @@ import torch
from packaging import version from packaging import version
from colossalai._analyzer.fx import symbolic_trace from colossalai._analyzer.fx import symbolic_trace
from colossalai.testing import clear_cache_before_run
from tests.kit.model_zoo import model_zoo from tests.kit.model_zoo import model_zoo
...@@ -43,6 +44,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): ...@@ -43,6 +44,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
def test_timm_models(): def test_timm_models():
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
......
...@@ -3,12 +3,14 @@ import torch ...@@ -3,12 +3,14 @@ import torch
from packaging import version from packaging import version
from torchaudio_utils import trace_and_compare from torchaudio_utils import trace_and_compare
from colossalai.testing import clear_cache_before_run
from tests.kit.model_zoo import model_zoo from tests.kit.model_zoo import model_zoo
# We cannot handle the tensors constructed with constant during forward, such as ``torch.empty(0).to(device=Proxy.device)`` # We cannot handle the tensors constructed with constant during forward, such as ``torch.empty(0).to(device=Proxy.device)``
# TODO: We could handle this case by hijacking torch.Tensor.to function. # TODO: We could handle this case by hijacking torch.Tensor.to function.
@pytest.mark.skip @pytest.mark.skip
@clear_cache_before_run()
def test_torchaudio_models(): def test_torchaudio_models():
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
......
...@@ -2,6 +2,7 @@ import pytest ...@@ -2,6 +2,7 @@ import pytest
import torch import torch
from colossalai._analyzer.fx import symbolic_trace from colossalai._analyzer.fx import symbolic_trace
from colossalai.testing import clear_cache_before_run
from tests.kit.model_zoo import model_zoo from tests.kit.model_zoo import model_zoo
BATCH = 2 BATCH = 2
...@@ -47,6 +48,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): ...@@ -47,6 +48,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
@clear_cache_before_run()
def test_torchrec_deepfm_models(): def test_torchrec_deepfm_models():
deepfm_models = model_zoo.get_sub_registry('deepfm') deepfm_models = model_zoo.get_sub_registry('deepfm')
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
......
...@@ -2,6 +2,7 @@ import pytest ...@@ -2,6 +2,7 @@ import pytest
import torch import torch
from colossalai._analyzer.fx import symbolic_trace from colossalai._analyzer.fx import symbolic_trace
from colossalai.testing import clear_cache_before_run
from tests.kit.model_zoo import model_zoo from tests.kit.model_zoo import model_zoo
BATCH = 2 BATCH = 2
...@@ -47,6 +48,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): ...@@ -47,6 +48,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
@clear_cache_before_run()
def test_torchrec_dlrm_models(): def test_torchrec_dlrm_models():
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
dlrm_models = model_zoo.get_sub_registry('dlrm') dlrm_models = model_zoo.get_sub_registry('dlrm')
......
import torch import torch
from colossalai._analyzer.fx import symbolic_trace from colossalai._analyzer.fx import symbolic_trace
from colossalai.testing import clear_cache_before_run
from tests.kit.model_zoo import model_zoo from tests.kit.model_zoo import model_zoo
@clear_cache_before_run()
def test_torchvision_models(): def test_torchvision_models():
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
tv_sub_registry = model_zoo.get_sub_registry('torchvision') tv_sub_registry = model_zoo.get_sub_registry('torchvision')
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
from checks_1d.check_layer_1d import * from checks_1d.check_layer_1d import *
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)
...@@ -40,9 +36,7 @@ def check_layer(rank, world_size, port): ...@@ -40,9 +36,7 @@ def check_layer(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_1d(): def test_1d():
world_size = 4 spawn(check_layer, 4)
run_func = partial(check_layer, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp from checks_2d.check_layer_2d import (
check_classifier_given_embed_weight,
check_classifier_no_given_weight,
check_embed,
check_layernorm,
check_linear,
check_loss,
check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
check_vocab_parallel_classifier_no_given_weight,
check_vocab_parallel_embed,
check_vocab_parallel_loss,
)
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import rerun_if_address_is_in_use
from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
check_vocab_parallel_loss)
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),) CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),)
...@@ -57,9 +62,7 @@ def check_layer_and_operation(rank, world_size, port): ...@@ -57,9 +62,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_2d(): def test_2d():
world_size = 4 spawn(check_layer_and_operation, 4)
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp from checks_2p5d.check_layer_2p5d import *
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import rerun_if_address_is_in_use
from checks_2p5d.check_layer_2p5d import *
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
CONFIG = dict(parallel=dict( CONFIG = dict(parallel=dict(
pipeline=dict(size=1), pipeline=dict(size=1),
...@@ -53,9 +50,7 @@ def check_layer_and_operation(rank, world_size, port): ...@@ -53,9 +50,7 @@ def check_layer_and_operation(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_2p5d(): def test_2p5d():
world_size = 4 spawn(check_layer_and_operation, 4)
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp from checks_3d.check_layer_3d import (
check_classifier_no_given_weight,
check_embed,
check_layernorm,
check_linear,
check_loss,
check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
check_vocab_parallel_classifier_no_given_weight,
check_vocab_parallel_embed,
check_vocab_parallel_loss,
)
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus
from checks_3d.check_layer_3d import (check_classifier_no_given_weight, check_embed, check_layernorm, check_linear,
check_loss, check_patch_embed, check_vocab_parallel_classifier_given_embed_weight,
check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
check_vocab_parallel_loss)
CONFIG = dict( CONFIG = dict(
parallel=dict( parallel=dict(
...@@ -52,9 +57,7 @@ def check_layer_and_operation(rank, world_size, port): ...@@ -52,9 +57,7 @@ def check_layer_and_operation(rank, world_size, port):
@skip_if_not_enough_gpus(min_gpus=8) @skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_3d(): def test_3d():
world_size = 8 spawn(check_layer_and_operation, 8)
run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
import pytest
from functools import partial
import numpy as np
import random import random
from typing import List
import numpy as np
import pytest
import torch import torch
import torch.multiprocessing as mp
import colossalai import colossalai
from colossalai.utils import free_port from colossalai.nn.parallel.layers import (
from colossalai.testing import rerun_if_address_is_in_use CachedEmbeddingBag,
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \ CachedParamMgr,
ColoTensor, ColoTensorSpec EvictionStrategy,
from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \ ParallelCachedEmbeddingBag,
ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig ParallelCachedEmbeddingBagTablewise,
from typing import List TablewiseEmbeddingBagConfig,
)
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
NUM_EMBED, EMBED_DIM = 10, 8 NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE = 8 BATCH_SIZE = 8
...@@ -44,6 +45,7 @@ def synthesize_1d_sparse_feature( ...@@ -44,6 +45,7 @@ def synthesize_1d_sparse_feature(
@pytest.mark.skip @pytest.mark.skip
@clear_cache_before_run()
def test_cachemgr(): def test_cachemgr():
model = torch.nn.EmbeddingBag(10000, 128) model = torch.nn.EmbeddingBag(10000, 128)
# 10 chunks, 5 in cuda # 10 chunks, 5 in cuda
...@@ -72,6 +74,7 @@ def test_cachemgr(): ...@@ -72,6 +74,7 @@ def test_cachemgr():
assert mgr.cuda_available_chunk_num == 5 assert mgr.cuda_available_chunk_num == 5
@clear_cache_before_run()
def test_reorder_with_freq(): def test_reorder_with_freq():
num_embed = 100 num_embed = 100
chunk_size = 1 chunk_size = 1
...@@ -102,7 +105,8 @@ def test_reorder_with_freq(): ...@@ -102,7 +105,8 @@ def test_reorder_with_freq():
f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
@pytest.mark.parametrize('use_LFU', [True, False]) @clear_cache_before_run()
@parameterize('use_LFU', [True, False])
def test_freq_aware_embed(use_LFU: bool): def test_freq_aware_embed(use_LFU: bool):
device = torch.device('cuda', 0) device = torch.device('cuda', 0)
evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
...@@ -148,7 +152,8 @@ def test_freq_aware_embed(use_LFU: bool): ...@@ -148,7 +152,8 @@ def test_freq_aware_embed(use_LFU: bool):
f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}" f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
@pytest.mark.parametrize('init_freq', [True, False]) @clear_cache_before_run()
@parameterize('init_freq', [True, False])
def test_lfu_strategy(init_freq: bool): def test_lfu_strategy(init_freq: bool):
# minimal test to check behavior # minimal test to check behavior
Bag = CachedEmbeddingBag(5, Bag = CachedEmbeddingBag(5,
...@@ -248,7 +253,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): ...@@ -248,7 +253,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
input0 [1,2,3] [6,7] [] input0 [1,2,3] [6,7] []
input1 [] [9] [13,15] input1 [] [9] [13,15]
input2 [1,5] [6,8] [11] input2 [1,5] [6,8] [11]
↑ ↑ ↑ ↑ ↑ ↑
rank 0 rank 0 rank 1 rank 0 rank 0 rank 1
in KJT format in KJT format
''' '''
...@@ -363,8 +368,7 @@ def run_dist(rank, world_size, port): ...@@ -363,8 +368,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_parallel_freq_aware_embed(world_size): def test_parallel_freq_aware_embed(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) spawn(run_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
import colossalai import pytest
import colossalai.nn as col_nn
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp
import pytest
from colossalai.core import global_context as gpc import colossalai
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.testing import rerun_if_address_is_in_use from colossalai.core import global_context as gpc
from functools import partial from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence'))) CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
...@@ -121,8 +118,8 @@ def check_ring_av(rank, world_size): ...@@ -121,8 +118,8 @@ def check_ring_av(rank, world_size):
'attention output cannot match' 'attention output cannot match'
def run_test(rank, world_size): def run_test(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=29500) colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port)
# check_ring_qk(rank, world_size) # check_ring_qk(rank, world_size)
check_ring_av(rank, world_size) check_ring_av(rank, world_size)
...@@ -134,9 +131,7 @@ def run_test(rank, world_size): ...@@ -134,9 +131,7 @@ def run_test(rank, world_size):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_sequence(): def test_sequence():
world_size = 4 spawn(run_test, 4)
run_func = partial(run_test, world_size=world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn
import colossalai import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts
from colossalai.context.moe_context import MOE_CONTEXT from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.utils.moe import sync_moe_model_param
BATCH_SIZE = 4 BATCH_SIZE = 4
DIM = 16 DIM = 16
...@@ -65,9 +64,7 @@ def run_test(rank, world_size, port): ...@@ -65,9 +64,7 @@ def run_test(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_grad_handler(): def test_grad_handler():
world_size = 4 spawn(run_test, 4)
run_func = partial(run_test, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.multiprocessing as mp
import colossalai import colossalai
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
from colossalai.context.moe_context import MOE_CONTEXT from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.testing import rerun_if_address_is_in_use from colossalai.core import global_context as gpc
from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
BATCH_SIZE = 16 BATCH_SIZE = 16
NUM_EXPERTS = 4 NUM_EXPERTS = 4
...@@ -90,15 +89,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f ...@@ -90,15 +89,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
@pytest.mark.parametrize("router", [Top1Router, Top2Router]) @pytest.mark.parametrize("router", [Top1Router, Top2Router])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_moe_kernel(rs, hidden_size, data_type, router): def test_moe_kernel(rs, hidden_size, data_type, router):
world_size = 4 spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router)
run_func = partial(run_routing,
world_size=world_size,
port=free_port(),
rs=rs,
hidden_size=hidden_size,
data_type=data_type,
router=router)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
import os import os
from functools import partial
import pytest import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai import colossalai
from colossalai.context import MOE_CONTEXT from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer.moe import load_moe_model, save_moe_model from colossalai.nn.layer.moe import load_moe_model, save_moe_model
from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port, get_current_device from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext from colossalai.zero import ColoInitContext
from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_tensor.common_utils import debug_print
from tests.test_zero.test_legacy.common import CONFIG from tests.test_zero.test_legacy.common import CONFIG
...@@ -46,8 +43,7 @@ def _run_dist(rank, world_size, port): ...@@ -46,8 +43,7 @@ def _run_dist(rank, world_size, port):
@pytest.mark.parametrize("world_size", [2, 4]) @pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size): def test_moe_checkpoint(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port()) spawn(_run_dist)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp
import colossalai import colossalai
from colossalai.context import MOE_CONTEXT from colossalai.context import MOE_CONTEXT
from colossalai.tensor import ColoParameter from colossalai.tensor import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port, get_current_device from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext from colossalai.zero import ColoInitContext
from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_tensor.common_utils import debug_print from tests.test_tensor.common_utils import debug_print
...@@ -52,8 +49,7 @@ def _run_dist(rank, world_size, port): ...@@ -52,8 +49,7 @@ def _run_dist(rank, world_size, port):
@pytest.mark.parametrize("world_size", [4]) @pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_moe_colo_init(world_size): def test_moe_colo_init(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port()) spawn(_run_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn
import colossalai import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Experts
from colossalai.context.moe_context import MOE_CONTEXT from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.nn.layer.moe import Experts
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.utils.moe import sync_moe_model_param from colossalai.utils.moe import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use
D_MODEL = 4 D_MODEL = 4
D_FF = 8 D_FF = 8
CONFIG = dict() CONFIG = dict()
def run_test(rank, port): def run_test(rank, world_size, port):
world_size = 4 world_size = 4
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
expert_module = nn.Linear expert_module = nn.Linear
...@@ -62,9 +61,7 @@ def run_test(rank, port): ...@@ -62,9 +61,7 @@ def run_test(rank, port):
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_moe_initialization(): def test_moe_initialization():
world_size = 4 spawn(run_test, 4)
run_func = partial(run_test, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import colossalai import colossalai
...@@ -10,8 +7,8 @@ from colossalai.context import MOE_CONTEXT ...@@ -10,8 +7,8 @@ from colossalai.context import MOE_CONTEXT
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn import CheckpointModule from colossalai.nn import CheckpointModule
from colossalai.nn.layer import MoeModule from colossalai.nn.layer import MoeModule
from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port, get_current_device from colossalai.utils import get_current_device
from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from tests.test_zero.test_legacy.common import CONFIG from tests.test_zero.test_legacy.common import CONFIG
...@@ -104,8 +101,7 @@ def _run_dist(rank, world_size, port): ...@@ -104,8 +101,7 @@ def _run_dist(rank, world_size, port):
@pytest.mark.parametrize("world_size", [2, 4]) @pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_moe_zero_init(world_size): def test_moe_zero_init(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port()) spawn(_run_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import colossalai import colossalai
from colossalai.context import MOE_CONTEXT from colossalai.context import MOE_CONTEXT
from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.nn import MoeLoss from colossalai.nn import MoeLoss
from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port
from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2 from colossalai.zero.legacy.sharded_model import ShardedModelV2
...@@ -67,8 +63,7 @@ def run_dist(rank, world_size, port): ...@@ -67,8 +63,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("world_size", [2])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_moe_zero_model(world_size): def test_moe_zero_model(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) spawn(run_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import colossalai import colossalai
from colossalai.amp import convert_to_apex_amp from colossalai.amp import convert_to_apex_amp
...@@ -10,8 +7,8 @@ from colossalai.context import MOE_CONTEXT ...@@ -10,8 +7,8 @@ from colossalai.context import MOE_CONTEXT
from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.nn import MoeLoss from colossalai.nn import MoeLoss
from colossalai.nn.optimizer import CPUAdam from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port, get_current_device from colossalai.utils import get_current_device
from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2 from colossalai.zero.legacy.sharded_model import ShardedModelV2
...@@ -116,8 +113,7 @@ def _run_dist(rank, world_size, port): ...@@ -116,8 +113,7 @@ def _run_dist(rank, world_size, port):
@pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("world_size", [2])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size): def test_moe_zero_optim(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port()) spawn(_run_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment