Unverified commit 80eba05b authored by Frank Lee, committed by GitHub

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
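
The refactor replaces the per-test boilerplate of `functools.partial`, `free_port()` and `torch.multiprocessing.spawn` with a single `spawn` helper exported from `colossalai.testing`. The helper's implementation is not part of this diff; the sketch below is only an illustration of what it presumably does, inferred from the call sites in the diffs that follow (for example `spawn(check_checkpoint_3d, 8)` and `spawn(run_dist, world_size, test_fn=fn)`). The `_free_port` function and the exact signature are assumptions, not the library's actual code.

import socket
from functools import partial

import torch.multiprocessing as mp


def _free_port() -> int:
    # Hypothetical helper: ask the OS for an unused TCP port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]


def spawn(func, nprocs=1, **kwargs):
    # Sketch only: each worker receives (rank, world_size, port, **kwargs),
    # so individual tests no longer build the partial / free_port / mp.spawn chain themselves.
    wrapped = partial(func, world_size=nprocs, port=_free_port(), **kwargs)
    mp.spawn(wrapped, nprocs=nprocs)

With a helper of this shape, a test that previously read `run_func = partial(run_dist, world_size=world_size, port=free_port()); mp.spawn(run_func, nprocs=world_size)` shrinks to `spawn(run_dist, world_size)`, as the diffs below show.
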
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 import pprint
-from functools import partial
 
-import colossalai.nn as col_nn
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
+
+import colossalai.nn as col_nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.utils import free_port, get_current_device, is_using_pp
+from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
+from colossalai.utils import is_using_pp
 from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
-from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
 
 
 def build_pipeline(model):
     from colossalai.pipeline.utils import partition_uniform
 
     pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
     pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
     depth = len(model)
     start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
     layers = []
     for i in range(depth):
         if start <= i < end:
             layers.append(model[i])
         else:
             layers.append(nn.Identity())
     return nn.Sequential(*tuple(layers))
 
 
 def check_equal(A, B):
     assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)
 
 
 def check_checkpoint_3d(rank, world_size, port):
     config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),)
 
     disable_existing_loggers()
     launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
 
     m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
     sd1 = m1.state_dict()
     if gpc.get_global_rank() == 0:
         print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
     save_checkpoint("test.pt", 0, m1)
 
     m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
     if is_using_pp():
         m2 = build_pipeline(m2)
 
     load_checkpoint("test.pt", m2)
     sd2 = m2.state_dict()
     if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
         sd2 = gather_pipeline_parallel_state_dict(sd2)
     print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")
 
     if gpc.get_global_rank() == 0:
         for k, v in sd1.items():
             assert k in sd2
             check_equal(v, sd2[k].to(torch.device("cpu")))
 
 
 @pytest.mark.dist
 @pytest.mark.skip("takes too long")
 @skip_if_not_enough_gpus(min_gpus=8)
-@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+@rerun_if_address_is_in_use()
 def test_checkpoint_3d():
-    world_size = 8
-    run_func = partial(check_checkpoint_3d, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_checkpoint_3d, 8)
 
 
 if __name__ == "__main__":
     test_checkpoint_3d()

@@ -3,20 +3,19 @@ from functools import partial
 from tempfile import TemporaryDirectory
 from typing import Dict
 
-import colossalai
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from colossalai.utils.checkpoint_io.io import load, save
-from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta)
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Adam, Optimizer
+
+import colossalai
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.utils.checkpoint_io.io import load, save
+from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta
 
 
 def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
     assert set(a.keys()) == set(b.keys())
@@ -120,14 +119,13 @@ def test_save_global_load_global(max_shard_size_gb: float):
     check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict())
 
 
-def run_dist(rank, world_size, port, func):
+def run_dist(rank, world_size, port, test_fn):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    func()
+    test_fn()
 
 
 def launch_dist(fn, world_size: int):
-    proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
-    mp.spawn(proc_fn, nprocs=world_size)
+    spawn(run_dist, world_size, test_fn=fn)
 
 
 def save_dist(dir_name: str, zero: bool):
...
-from colossalai.utils.checkpoint_io.meta import ParamDistMeta
-from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
-from colossalai.utils.checkpoint_io.io import save, merge
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from tempfile import TemporaryDirectory
-from torch.optim import Adam
-from functools import partial
-import torch
 import os
+from functools import partial
+from tempfile import TemporaryDirectory
+
 import pytest
-import colossalai
-import torch.nn as nn
+import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch.nn as nn
+from torch.optim import Adam
+
+import colossalai
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
+from colossalai.utils.checkpoint_io.io import merge, save
+from colossalai.utils.checkpoint_io.meta import ParamDistMeta
 
 
 class DummyModel(nn.Module):
@@ -52,7 +52,7 @@ def test_merge_global():
         assert len(os.listdir(output_dir)) == 0
 
 
-def run_dist(rank, world_size, port, func):
+def run_dist(rank, world_size, port, test_fn):
     colossalai.launch(config={'parallel': {
         'tensor': {
             'mode': '1d',
@@ -64,7 +64,7 @@ def run_dist(rank, world_size, port, func):
                       host='localhost',
                       port=port,
                       backend='nccl')
-    func()
+    test_fn()
 
 
 def run_save_dist(dir_name: str, zero: bool):
@@ -100,8 +100,7 @@ def test_merge_tp_dp(zero: bool):
     with TemporaryDirectory() as dir_name:
         fn = partial(run_save_dist, dir_name, zero)
         world_size = 4
-        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
-        mp.spawn(proc_fn, nprocs=world_size)
+        spawn(run_dist, world_size, test_fn=fn)
     with TemporaryDirectory() as output_dir:
         merge(dir_name, output_dir)
         assert len(os.listdir(output_dir)) == 5
...
@@ -2,19 +2,23 @@ import os
 from functools import partial
 from tempfile import TemporaryDirectory
 
-import colossalai
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from torch.optim import Adam
+
+import colossalai
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
 from colossalai.utils.checkpoint_io.io import redist, save
-from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta,
-                                                 RedistMeta)
-from torch.optim import Adam
+from colossalai.utils.checkpoint_io.meta import (
+    ParamDistMeta,
+    ParamRedistMeta,
+    PipelineRedistMeta,
+    RankRedistMeta,
+    RedistMeta,
+)
 
 
 class DummyModel(nn.Module):
@@ -105,7 +109,7 @@ def test_global_to_dist():
         check_checkpoint_shape(output_dir)
 
 
-def run_dist(rank, world_size, port, func):
+def run_dist(rank, world_size, port, test_fn):
     colossalai.launch(config={'parallel': {
         'tensor': {
             'mode': '1d',
@@ -117,7 +121,7 @@ def run_dist(rank, world_size, port, func):
                       host='localhost',
                       port=port,
                       backend='nccl')
-    func()
+    test_fn()
 
 
 def run_save_dist(dir_name: str, zero: bool):
@@ -133,8 +137,7 @@ def test_dist_to_dist(zero: bool):
     with TemporaryDirectory() as dir_name:
         fn = partial(run_save_dist, dir_name, zero)
         world_size = 4
-        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
-        mp.spawn(proc_fn, nprocs=world_size)
+        spawn(run_dist, world_size, test_fn=fn)
     with TemporaryDirectory() as output_dir:
         redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
         if not zero:
...
@@ -3,21 +3,24 @@ from functools import partial
 from tempfile import TemporaryDirectory
 from typing import Dict
 
-import colossalai
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME,
-                                                     OTHER_CKPT_FILE_NAME)
-from colossalai.utils.checkpoint_io.io import save
-from colossalai.utils.checkpoint_io.meta import ParamDistMeta
 from torch import Tensor
 from torch.optim import Adam
+
+import colossalai
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.utils.checkpoint_io.constant import (
+    GLOBAL_META_FILE_NAME,
+    META_CKPT_FILE_NAME,
+    MODEL_CKPT_FILE_NAME,
+    OTHER_CKPT_FILE_NAME,
+)
+from colossalai.utils.checkpoint_io.io import save
+from colossalai.utils.checkpoint_io.meta import ParamDistMeta
 
 
 def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
     assert set(a.keys()) == set(b.keys())
@@ -104,9 +107,9 @@ def test_save_global_shard():
     })
 
 
-def run_dist(rank, world_size, port, func):
+def run_dist(rank, world_size, port, test_fn):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    func()
+    test_fn()
 
 
 def run_save_dist(dir_name):
@@ -124,8 +127,7 @@ def test_save_dist():
     with TemporaryDirectory() as dir_name:
         fn = partial(run_save_dist, dir_name)
         world_size = 2
-        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
-        mp.spawn(proc_fn, nprocs=world_size)
+        spawn(run_dist, world_size, test_fn=fn)
         assert len(os.listdir(dir_name)) == 8
         global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
         assert len(global_meta['meta']) == 2
...
 import os
 import shutil
 from copy import deepcopy
-from functools import partial
 
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR
 
 import colossalai
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext
@@ -202,13 +199,7 @@ def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
 # @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
 @rerun_if_address_is_in_use()
 def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
-    run_func = partial(run_dist,
-                       world_size=world_size,
-                       port=free_port(),
-                       use_ddp=use_ddp,
-                       use_mp_reload=use_mp_reload,
-                       test_scheduler=test_scheduler)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler)
 
 
 if __name__ == '__main__':
...
 import torch
-import torch.multiprocessing as mp
 
 import colossalai
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.zero.legacy.sharded_param import ShardedTensor
 
 
-def run_tensor_move(rank):
-    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
+def run_tensor_move(rank, world_size, port):
+    colossalai.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl')
 
     src_t = torch.ones(2, 3).cuda()
     tgt_t = torch.zeros(2, 3)
@@ -36,7 +34,7 @@ def run_tensor_move(rank):
 @rerun_if_address_is_in_use()
 def test_tensor_move():
-    mp.spawn(run_tensor_move, nprocs=1)
+    spawn(run_tensor_move, 1)
 
 
 if __name__ == '__main__':
...
@@ -5,6 +5,7 @@ import torch
 from einops import rearrange
 
 from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN
+from colossalai.testing import clear_cache_before_run, parameterize
 
 if HAS_MEM_EFF_ATTN:
     from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
@@ -22,7 +23,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
-@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
+@clear_cache_before_run()
+@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
 def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
     D = H * D_HEAD
@@ -42,7 +44,8 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
-@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
+@clear_cache_before_run()
+@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
 def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
     D = H * D_HEAD
@@ -65,7 +68,8 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
-@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
+@clear_cache_before_run()
+@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
 def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
     D = H * D_HEAD
@@ -84,7 +88,8 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
 @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
-@pytest.mark.parametrize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
+@clear_cache_before_run()
+@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
 def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
     D = H * D_HEAD
...
-from functools import partial
 from typing import Optional
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 
 import colossalai
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.common import print_rank_0
 
 try:
@@ -105,9 +102,7 @@ def run_dist(rank, world_size, port) -> None:
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_dist_lazy_init():
-    world_size = 4
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 4)
 
 
 if __name__ == '__main__':
...
 import pytest
 
 import colossalai
+from colossalai.testing import spawn
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity
-from colossalai.utils import free_port
-from functools import partial
-import torch.multiprocessing as mp
+from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
 
 
 def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
@@ -24,8 +21,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [3, 4])
 def test_memory_utils(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
...
-from colossalai.tensor import distspec, ColoTensorSpec, ProcessGroup
-from colossalai.tensor.colo_parameter import ColoParameter
-import colossalai
 import pytest
 import torch
-import torch.multiprocessing as mp
-from colossalai.logging import disable_existing_loggers
-from colossalai.utils import free_port, get_current_device
+from torch.nn.parameter import Parameter
 from torch.nn.utils import clip_grad_norm_
-from functools import partial
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec
+from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
 from colossalai.utils.common import clip_grad_norm
-from torch.nn.parameter import Parameter
 
 
 def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
@@ -71,8 +70,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
 def test_zero_clip_grad(world_size: int):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
...
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 import copy
 from functools import partial
 
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.utils import checkpoint, clip_grad_norm_fp32
 from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy
 from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
@@ -106,8 +104,7 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_zero_clip_grad():
     world_size = 4
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
...
-from functools import partial
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 
 import colossalai
 from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.zero.gemini.chunk import ChunkManager
 from tests.test_tensor.common_utils import debug_print
@@ -64,8 +60,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [2])
 @rerun_if_address_is_in_use()
 def test_chunk_manager(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
...
-from functools import partial
 
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 
 import colossalai
 from colossalai.tensor import ColoParameter
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port, get_current_device
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
 from colossalai.zero.gemini import TensorState
 from colossalai.zero.gemini.chunk import Chunk
@@ -117,8 +114,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 2, 4])
 @rerun_if_address_is_in_use()
 def test_chunk_function(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
...
-from functools import partial
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
@@ -10,8 +7,7 @@ import colossalai
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor import ProcessGroup
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
 from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
@@ -103,8 +99,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
...
-from functools import partial
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 
 import colossalai
 from colossalai.tensor import ProcessGroup
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.zero import ColoInitContext, ZeroDDP
 from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
 from colossalai.zero.gemini.gemini_mgr import GeminiManager
 from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
@@ -98,8 +94,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_gemini_use_rmt(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
...
-import os
-from functools import partial
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 
 import colossalai
 from colossalai.tensor import ColoParameter
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext, GeminiDDP
 from colossalai.zero.gemini.utils import get_static_torch_model
@@ -50,8 +45,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_convert_torch_module(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
...
-from functools import partial
-from time import time
 
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
 from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
 from colossalai.zero.gemini.gemini_mgr import GeminiManager
 from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_tensor.common_utils import debug_print, set_seed
+from tests.test_tensor.common_utils import set_seed
 
 
 def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
@@ -105,8 +100,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
 def test_grad_clip(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
...
-from functools import partial
 from typing import Callable
 
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper
 from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
@@ -128,8 +125,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_inference(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
...
-from functools import partial
 
 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor import ColoParameter, ColoTensor
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx
 from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
@@ -157,8 +152,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_optim(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size)
 
 
 if __name__ == '__main__':
...