Unverified Commit 80eba05b authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
import colossalai
import torch
import pytest import pytest
import torch
import torch.nn as nn import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor, ProcessGroup import colossalai
from colossalai.tensor import ColoTensorSpec from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
from functools import partial
from tests.test_tensor.common_utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d
class Conv1D(nn.Module): class Conv1D(nn.Module):
...@@ -69,8 +66,7 @@ def run_dist(rank, world_size, port): ...@@ -69,8 +66,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_addmm_1d(world_size): def test_addmm_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) spawn(run_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
import pytest
import torch
from torch.nn import functional as F from torch.nn import functional as F
from functools import partial
import colossalai import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal
def run_with_spec(spec_init_func): def run_with_spec(spec_init_func):
...@@ -39,8 +36,7 @@ def run_dist(rank, world_size, port): ...@@ -39,8 +36,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_embedding_bag_1d(world_size): def test_embedding_bag_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) spawn(run_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
import pytest
import torch
from torch.nn import functional as F from torch.nn import functional as F
from functools import partial
import colossalai import colossalai
import pytest from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
import torch from colossalai.testing import rerun_if_address_is_in_use, spawn
import torch.multiprocessing as mp from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
def run_with_spec(spec_init_func, pg: ProcessGroup): def run_with_spec(spec_init_func, pg: ProcessGroup):
...@@ -40,8 +37,7 @@ def run_dist(rank, world_size, port): ...@@ -40,8 +37,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_embedding_1d(world_size): def test_embedding_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) spawn(run_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import colossalai
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port import colossalai
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
def run_with_spec(spec_init_func, split_bias): def run_with_spec(spec_init_func, split_bias):
...@@ -44,8 +41,7 @@ def run_dist(rank, world_size, port): ...@@ -44,8 +41,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_linear_1d(world_size): def test_linear_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) spawn(run_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
import torch import pytest
import pytest import torch
import colossalai import torch.nn.functional as F
import torch.nn.functional as F
import torch.multiprocessing as mp import colossalai
from functools import partial from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern def check_cross_entropy():
input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
def check_cross_entropy(): with torch.no_grad():
input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True) input_ct.copy_(input_t)
input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
with torch.no_grad(): target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())
input_ct.copy_(input_t)
world_size = torch.distributed.get_world_size()
target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device()) pg = ProcessGroup(tp_degree=world_size)
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
world_size = torch.distributed.get_world_size() input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
pg = ProcessGroup(tp_degree=world_size) input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()])) output = F.cross_entropy(input_t, target)
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) output_colo = F.cross_entropy(input_shard, target)
assert torch.allclose(output_colo, output)
output = F.cross_entropy(input_t, target)
output_colo = F.cross_entropy(input_shard, target) output.backward()
assert torch.allclose(output_colo, output) output_colo.backward()
output.backward() assert torch.allclose(input_t.grad, input_ct.grad)
output_colo.backward()
assert torch.allclose(input_t.grad, input_ct.grad) def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_cross_entropy()
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_cross_entropy() @pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
@pytest.mark.dist def test_loss_func(world_size):
@pytest.mark.parametrize('world_size', [1, 2]) spawn(run_dist, world_size)
@rerun_if_address_is_in_use()
def test_loss_func(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) if __name__ == '__main__':
mp.spawn(run_func, nprocs=world_size) test_loss_func(1)
if __name__ == '__main__':
test_loss_func(1)
import torch
import pytest import pytest
import colossalai import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.multiprocessing as mp
from functools import partial
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec
from colossalai.utils import get_current_device
from torch.nn import Parameter from torch.nn import Parameter
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
def _run_layer_norm(): def _run_layer_norm():
...@@ -66,8 +64,7 @@ def run_dist(rank, world_size, port): ...@@ -66,8 +64,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [2]) @pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_element_wise_ops(world_size): def test_element_wise_ops(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) spawn(run_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
def run_dist2(rank, world_size, port): def run_dist2(rank, world_size, port):
...@@ -79,8 +76,7 @@ def run_dist2(rank, world_size, port): ...@@ -79,8 +76,7 @@ def run_dist2(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1]) @pytest.mark.parametrize('world_size', [1])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_ln(world_size): def test_ln(world_size):
run_func = partial(run_dist2, world_size=world_size, port=free_port()) spawn(run_dist2, world_size)
mp.spawn(run_func, nprocs=world_size)
def check_all(): def check_all():
......
from functools import partial import pytest
import torch
import colossalai import torch.distributed as dist
import pytest
import torch import colossalai
import torch.multiprocessing as mp from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
import torch.distributed as dist from colossalai.tensor.distspec import DistPlacementPattern
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port, get_current_device from colossalai.utils import get_current_device
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor, ShardSpec from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d
from colossalai.tensor.distspec import DistPlacementPattern
from tests.test_tensor.common_utils import split_param_row_tp1d, split_param_col_tp1d, debug_print
def exam_view_core(pg):
# the case of replicated ColoTensors
def exam_view_core(pg): x = torch.randn(4, 4).cuda()
# the case of replicated ColoTensors x_colo = ColoTensor(x, ColoTensorSpec(pg))
x = torch.randn(4, 4).cuda()
x_colo = ColoTensor(x, ColoTensorSpec(pg)) y = x.view(2, -1, 2)
y_colo = x_colo.view(2, -1, 2)
y = x.view(2, -1, 2)
y_colo = x_colo.view(2, -1, 2) assert torch.all(y == y_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
assert torch.all(y == y_colo) # the perfect case of col-sliced ColoTensors
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE split_param_col_tp1d(x_colo, pg)
# the perfect case of col-sliced ColoTensors
split_param_col_tp1d(x_colo, pg) z = x.view(torch.Size((2, 1, 2, -1)))
z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
z = x.view(torch.Size((2, 1, 2, -1))) if dist.get_rank() == 0:
z_colo = x_colo.view(torch.Size((2, 1, 2, -1))) z = z[:, :, :, 0:2]
if dist.get_rank() == 0: else:
z = z[:, :, :, 0:2] z = z[:, :, :, 2:]
else: assert torch.all(z == z_colo)
z = z[:, :, :, 2:] assert z_colo.dist_spec == x_colo.dist_spec
assert torch.all(z == z_colo) # the perfect case of row-sliced ColoTensors
assert z_colo.dist_spec == x_colo.dist_spec split_param_row_tp1d(x_colo, pg)
# the perfect case of row-sliced ColoTensors
split_param_row_tp1d(x_colo, pg) z = x.view(torch.Size((-1, 2, 2)))
z_colo = x_colo.view(torch.Size((-1, 2, 2)))
z = x.view(torch.Size((-1, 2, 2))) if dist.get_rank() == 0:
z_colo = x_colo.view(torch.Size((-1, 2, 2))) z = z[0:2, :, :]
if dist.get_rank() == 0: else:
z = z[0:2, :, :] z = z[2:, :, :]
else: assert torch.all(z == z_colo)
z = z[2:, :, :] assert z_colo.dist_spec == x_colo.dist_spec
assert torch.all(z == z_colo) # the normal case of row-sliced ColoTensors
assert z_colo.dist_spec == x_colo.dist_spec z = x.view(-1, 2, 2, 2)
# the normal case of row-sliced ColoTensors z_colo = x_colo.view(-1, 2, 2, 2)
z = x.view(-1, 2, 2, 2) assert torch.all(z == z_colo)
z_colo = x_colo.view(-1, 2, 2, 2) assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
assert torch.all(z == z_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
def exam_view_autograd(pg):
x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
def exam_view_autograd(pg): y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
x = torch.randn(8, 2, device=get_current_device(), requires_grad=True) with torch.no_grad():
y = torch.randn(8, 2, device=get_current_device(), requires_grad=True) y.copy_(x)
with torch.no_grad(): y = ColoTensor(y, ColoTensorSpec(pg))
y.copy_(x) y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
y = ColoTensor(y, ColoTensorSpec(pg))
y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()])) xx = x.view(2, 2, -1)
yy_slice = y_slice.view(2, 2, -1)
xx = x.view(2, 2, -1) yy = yy_slice.to_replicate()
yy_slice = y_slice.view(2, 2, -1) grad = torch.randn(2, 2, 4, device=get_current_device())
yy = yy_slice.to_replicate()
grad = torch.randn(2, 2, 4, device=get_current_device()) xx.backward(grad)
yy.backward(grad)
xx.backward(grad) assert torch.all(x.grad == y.grad)
yy.backward(grad)
assert torch.all(x.grad == y.grad)
def exam_view_errors(pg):
x = torch.randn(8, 2, device=get_current_device())
def exam_view_errors(pg): x = ColoTensor(x, ColoTensorSpec(pg))
x = torch.randn(8, 2, device=get_current_device()) split_param_row_tp1d(x, pg)
x = ColoTensor(x, ColoTensorSpec(pg))
split_param_row_tp1d(x, pg) x.view('a', 'b', 'c')
x.view(8, -1)
x.view('a', 'b', 'c') x.view([-2, -2, -2])
x.view(8, -1) x.view((-1, -1, -1))
x.view([-2, -2, -2])
x.view((-1, -1, -1))
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
def run_dist(rank, world_size, port): pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') exam_view_core(pg)
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) exam_view_autograd(pg)
exam_view_core(pg) # exam_view_errors(pg)
exam_view_autograd(pg)
# exam_view_errors(pg)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.dist @rerun_if_address_is_in_use()
@pytest.mark.parametrize('world_size', [2]) def test_view(world_size):
@rerun_if_address_is_in_use() spawn(run_dist, world_size)
def test_view(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) if __name__ == '__main__':
test_view(2)
if __name__ == '__main__':
test_view(2)
...@@ -2,7 +2,7 @@ import math ...@@ -2,7 +2,7 @@ import math
import torch import torch
from colossalai.testing import parameterize from colossalai.testing import clear_cache_before_run, parameterize
def torch_adam_update( def torch_adam_update(
...@@ -46,6 +46,7 @@ def assertTrue(condition, msg): ...@@ -46,6 +46,7 @@ def assertTrue(condition, msg):
assert condition, msg assert condition, msg
@clear_cache_before_run()
@parameterize('adamw', [True, False]) @parameterize('adamw', [True, False])
@parameterize('step', [1, 2]) @parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half]) @parameterize('p_dtype', [torch.float, torch.half])
......
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.optim.adam import Adam
from torch.optim import AdamW from torch.optim import AdamW
from torch.optim.adam import Adam
from colossalai.nn.optimizer.fused_adam import FusedAdam from colossalai.nn.optimizer.fused_adam import FusedAdam
from colossalai.testing import parameterize from colossalai.testing import clear_cache_before_run, parameterize
class FC(nn.Module): class FC(nn.Module):
...@@ -17,6 +17,7 @@ class FC(nn.Module): ...@@ -17,6 +17,7 @@ class FC(nn.Module):
return self.fc(x) return self.fc(x)
@clear_cache_before_run()
@parameterize('adamw', [False, True]) @parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half]) @parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half]) @parameterize('g_dtype', [torch.float, torch.half])
......
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from numpy import dtype from numpy import dtype
from colossalai.testing import parameterize from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils import multi_tensor_applier from colossalai.utils import multi_tensor_applier
...@@ -41,6 +41,7 @@ def torch_adam_update( ...@@ -41,6 +41,7 @@ def torch_adam_update(
param.addcdiv_(exp_avg, denom, value=-step_size) param.addcdiv_(exp_avg, denom, value=-step_size)
@clear_cache_before_run()
@parameterize('adamw', [False, True]) @parameterize('adamw', [False, True])
@parameterize('step', [1, 2]) @parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half]) @parameterize('p_dtype', [torch.float, torch.half])
......
...@@ -4,11 +4,12 @@ from torch.optim import AdamW ...@@ -4,11 +4,12 @@ from torch.optim import AdamW
from torch.optim.adam import Adam from torch.optim.adam import Adam
from colossalai.nn.optimizer.hybrid_adam import HybridAdam from colossalai.nn.optimizer.hybrid_adam import HybridAdam
from colossalai.testing import parameterize from colossalai.testing import clear_cache_before_run, parameterize
RE = 3 RE = 3
@clear_cache_before_run()
@parameterize('adamw', [False, True]) @parameterize('adamw', [False, True])
@parameterize('device', ['cpu', 'cuda:0']) @parameterize('device', ['cpu', 'cuda:0'])
@parameterize('p_dtype', [torch.float]) @parameterize('p_dtype', [torch.float])
......
import pytest import pytest
import torch import torch
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.optimizer import CPUAdam, HybridAdam from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize
from tests.components_to_test.registry import non_distributed_component_funcs
def move_some_params_to_cuda(model, torch_model): def move_some_params_to_cuda(model, torch_model):
...@@ -16,9 +18,10 @@ def check_params_equal(model, torch_model): ...@@ -16,9 +18,10 @@ def check_params_equal(model, torch_model):
assert torch.allclose(p, torch_p, atol=1e-3), f'diff: {torch.abs(p - torch_p)}' assert torch.allclose(p, torch_p, atol=1e-3), f'diff: {torch.abs(p - torch_p)}'
@pytest.mark.parametrize('nvme_offload_fraction', [0.0, 0.5, 1.0]) @clear_cache_before_run()
@pytest.mark.parametrize('nvme_offload_dir', ['./offload', None]) @parameterize('nvme_offload_fraction', [0.0, 0.5, 1.0])
@pytest.mark.parametrize('adam_cls', [CPUAdam, HybridAdam]) @parameterize('nvme_offload_dir', ['./offload', None])
@parameterize('adam_cls', [CPUAdam, HybridAdam])
def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls):
get_components_func = non_distributed_component_funcs.get_callable('simple_net') get_components_func = non_distributed_component_funcs.get_callable('simple_net')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
......
...@@ -6,13 +6,14 @@ import torch ...@@ -6,13 +6,14 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.distributed.rpc as rpc import torch.distributed.rpc as rpc
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai import launch
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.pipeline_process_group import ppg
from torch import nn from torch import nn
from torch._C._distributed_rpc import _is_current_rpc_agent_set from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.optim import SGD, Adam, Optimizer, RMSprop from torch.optim import SGD, Adam, Optimizer, RMSprop
from colossalai import launch
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.pipeline_process_group import ppg
rpc_is_initialized = _is_current_rpc_agent_set rpc_is_initialized = _is_current_rpc_agent_set
...@@ -20,7 +21,9 @@ def color_debug(text, prefix=' ', color='blue'): ...@@ -20,7 +21,9 @@ def color_debug(text, prefix=' ', color='blue'):
color = color.upper() color = color.upper()
print(getattr(Back, color), prefix, Style.RESET_ALL, text) print(getattr(Back, color), prefix, Style.RESET_ALL, text)
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, dim: int, layers: int): def __init__(self, dim: int, layers: int):
super().__init__() super().__init__()
self.layers = torch.nn.ModuleList() self.layers = torch.nn.ModuleList()
...@@ -32,8 +35,10 @@ class MLP(nn.Module): ...@@ -32,8 +35,10 @@ class MLP(nn.Module):
for layer in self.layers: for layer in self.layers:
x = layer(x) x = layer(x)
return x.sum() return x.sum()
class DAG_MLP(nn.Module): class DAG_MLP(nn.Module):
def __init__(self, dim: int, layers: int): def __init__(self, dim: int, layers: int):
super().__init__() super().__init__()
self.layers = torch.nn.ModuleList() self.layers = torch.nn.ModuleList()
...@@ -48,6 +53,7 @@ class DAG_MLP(nn.Module): ...@@ -48,6 +53,7 @@ class DAG_MLP(nn.Module):
y = self.dag_layer(y) y = self.dag_layer(y)
return x.sum(), y.sum() return x.sum(), y.sum()
class RpcTestModel(nn.Module): class RpcTestModel(nn.Module):
def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
......
import torch
import pytest
import os import os
import torch.multiprocessing as mp from functools import partial
import torch.distributed.rpc as rpc
from torch import nn import pytest
import torch
import torch.distributed.rpc as rpc
from rpc_test_utils import DAG_MLP, MLP
from torch._C._distributed_rpc import _is_current_rpc_agent_set from torch._C._distributed_rpc import _is_current_rpc_agent_set
from colossalai import launch from colossalai import launch
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.middleware.adaptor import get_fx_topology
from colossalai.pipeline.pipeline_process_group import ppg from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.fx import ColoTracer
from colossalai.pipeline.middleware.adaptor import get_fx_topology
from rpc_test_utils import MLP, DAG_MLP
from functools import partial
from colossalai.testing import parameterize, rerun_if_address_is_in_use
# global variable for model created # global variable for model created
batch_size = 16 batch_size = 16
dim = 10 dim = 10
rpc_is_initialized = _is_current_rpc_agent_set rpc_is_initialized = _is_current_rpc_agent_set
def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
model.eval() model.eval()
tracer = ColoTracer() tracer = ColoTracer()
...@@ -34,13 +34,15 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): ...@@ -34,13 +34,15 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
for submodule in split_submodules: for submodule in split_submodules:
if isinstance(submodule, torch.fx.GraphModule): if isinstance(submodule, torch.fx.GraphModule):
setattr(submodule, '_topo', topo) setattr(submodule, '_topo', topo)
return split_submodules[pp_rank+1] return split_submodules[pp_rank + 1]
def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int): def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
torch.manual_seed(1024) torch.manual_seed(1024)
partition = create_partition_module(pp_rank, stage_num, model, data_kwargs) partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
return partition return partition
def run_master(model_cls, world_size, forward_only): def run_master(model_cls, world_size, forward_only):
torch.manual_seed(100) torch.manual_seed(100)
...@@ -50,23 +52,27 @@ def run_master(model_cls, world_size, forward_only): ...@@ -50,23 +52,27 @@ def run_master(model_cls, world_size, forward_only):
chunk = 1 chunk = 1
num_microbatches = 8 num_microbatches = 8
use_checkpoint = 'store_true' use_checkpoint = 'store_true'
if model_cls == MLP: if model_cls == MLP:
def data_gen(): def data_gen():
x = torch.zeros((batch_size, dim)) x = torch.zeros((batch_size, dim))
kwargs = dict(x=x) kwargs = dict(x=x)
return kwargs return kwargs
model = model_cls(dim, stage_num * 3) model = model_cls(dim, stage_num * 3)
if forward_only: if forward_only:
labels = None labels = None
else: else:
labels = 1 labels = 1
elif model_cls == DAG_MLP: elif model_cls == DAG_MLP:
def data_gen(): def data_gen():
x = torch.zeros((batch_size, dim)) x = torch.zeros((batch_size, dim))
y = torch.zeros((batch_size, dim)) y = torch.zeros((batch_size, dim))
kwargs = dict(x=x, y=y) kwargs = dict(x=x, y=y)
return kwargs return kwargs
model = model_cls(dim, stage_num * 3) model = model_cls(dim, stage_num * 3)
if forward_only: if forward_only:
labels = None labels = None
...@@ -74,15 +80,17 @@ def run_master(model_cls, world_size, forward_only): ...@@ -74,15 +80,17 @@ def run_master(model_cls, world_size, forward_only):
labels = 1 labels = 1
else: else:
pass pass
data_kwargs = data_gen() data_kwargs = data_gen()
engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs), engine = OneFOneBPipelineEngine(
stage_num=stage_num, partition_fn=partial(partition, model, data_kwargs),
num_microbatches=num_microbatches, stage_num=stage_num,
device=device, num_microbatches=num_microbatches,
chunk=chunk, device=device,
checkpoint=use_checkpoint,) chunk=chunk,
checkpoint=use_checkpoint,
)
if not forward_only: if not forward_only:
engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3) engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3)
...@@ -90,13 +98,14 @@ def run_master(model_cls, world_size, forward_only): ...@@ -90,13 +98,14 @@ def run_master(model_cls, world_size, forward_only):
input_x = torch.randn((batch_size, dim), device=device) input_x = torch.randn((batch_size, dim), device=device)
input_y = torch.randn((batch_size, dim), device=device) input_y = torch.randn((batch_size, dim), device=device)
logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only) logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only)
def run_worker(rank, model_cls, world_size, forward_only, master_func):
def run_worker(rank, world_size, port, model_cls, forward_only, master_func):
master_addr = 'localhost' master_addr = 'localhost'
master_port = 29020 master_port = 29020
os.environ['MASTER_ADDR'] = master_addr os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = str(master_port) os.environ['MASTER_PORT'] = str(master_port)
disable_existing_loggers() disable_existing_loggers()
launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False) launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False)
...@@ -113,7 +122,8 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func): ...@@ -113,7 +122,8 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func):
# barrier here # barrier here
if rpc_is_initialized(): if rpc_is_initialized():
rpc.shutdown() rpc.shutdown()
@pytest.mark.skip("skip due to CI torch version 1.11") @pytest.mark.skip("skip due to CI torch version 1.11")
@parameterize('model_cls', [MLP, DAG_MLP]) @parameterize('model_cls', [MLP, DAG_MLP])
@parameterize('forward_only', [True, False]) @parameterize('forward_only', [True, False])
...@@ -122,7 +132,14 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func): ...@@ -122,7 +132,14 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func):
def test_pp_middleware_fwd(model_cls, forward_only): def test_pp_middleware_fwd(model_cls, forward_only):
world_size = 4 world_size = 4
master_func = run_master master_func = run_master
mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size) spawn(
run_worker,
world_size,
model_cls=model_cls,
forward_only=forward_only,
master_func=master_func,
)
if __name__ == "__main__": if __name__ == "__main__":
test_pp_middleware_fwd() test_pp_middleware_fwd()
\ No newline at end of file
import torch import torch
import torch.multiprocessing as mp
from colossalai.pipeline.pipelinable import PipelinableContext from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn
from colossalai.testing import rerun_on_exception
NUM_CHUNKS = 1 NUM_CHUNKS = 1
PIPELINE_SIZE = 2 PIPELINE_SIZE = 2
...@@ -27,7 +25,7 @@ class MLP(torch.nn.Module): ...@@ -27,7 +25,7 @@ class MLP(torch.nn.Module):
return x return x
def run_pipelinable(rank): def run_pipelinable(rank, world_size, port):
pipelinable = PipelinableContext() pipelinable = PipelinableContext()
with pipelinable: with pipelinable:
model = MLP() model = MLP()
...@@ -50,9 +48,9 @@ def run_pipelinable(rank): ...@@ -50,9 +48,9 @@ def run_pipelinable(rank):
assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") @rerun_if_address_is_in_use()
def test_pipelinable(): def test_pipelinable():
mp.spawn(run_pipelinable, nprocs=1) spawn(run_pipelinable, 1)
if __name__ == '__main__': if __name__ == '__main__':
......
import os import os
import torch.distributed.rpc as rpc import torch.distributed.rpc as rpc
import torch.multiprocessing as mp from rpc_test_utils import pg_parse_args, rpc_is_initialized
import pytest
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.initialize import launch from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from rpc_test_utils import pg_parse_args, rpc_is_initialized from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.testing import spawn
def run_worker(rank, args): def run_worker(rank, args):
...@@ -40,4 +39,4 @@ def run_worker(rank, args): ...@@ -40,4 +39,4 @@ def run_worker(rank, args):
if __name__ == "__main__": if __name__ == "__main__":
args = pg_parse_args() args = pg_parse_args()
world_size = args.world_size world_size = args.world_size
mp.spawn(run_worker, args=(args,), nprocs=world_size) spawn(run_worker, world_size, args=args)
\ No newline at end of file
import math import math
import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import pytest
import colossalai import colossalai
import torch.multiprocessing as mp from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port
from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec
from functools import partial
def run(): def run():
...@@ -58,8 +57,7 @@ def run_dist(rank, world_size, port): ...@@ -58,8 +57,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_dist_spec_mgr(world_size): def test_dist_spec_mgr(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) spawn(run_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
import torch
import pytest import pytest
from colossalai.tensor import ColoTensor import torch
from numpy import allclose from numpy import allclose
import colossalai import colossalai
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
import torch.multiprocessing as mp from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port
from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec
from functools import partial
def _run_tensor_indexing(): def _run_tensor_indexing():
...@@ -152,8 +146,7 @@ def run_dist_tests(rank, world_size, port): ...@@ -152,8 +146,7 @@ def run_dist_tests(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 2]) @pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_dist_cases(world_size): def test_dist_cases(world_size):
run_func = partial(run_dist_tests, world_size=world_size, port=free_port()) spawn(run_dist_tests, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai import colossalai
from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
...@@ -145,8 +141,7 @@ def run_dist(rank, world_size, port, use_ddp): ...@@ -145,8 +141,7 @@ def run_dist(rank, world_size, port, use_ddp):
@pytest.mark.parametrize('use_ddp', [False, True]) @pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_gpt(world_size, use_ddp): def test_gpt(world_size, use_ddp):
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp) spawn(run_dist, world_size, use_ddp=use_ddp)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
from functools import partial
import pytest import pytest
import torch import torch
import torch.multiprocessing as mp
import colossalai import colossalai
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ProcessGroup from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
...@@ -313,8 +309,7 @@ def run_model_dist(rank, world_size, port): ...@@ -313,8 +309,7 @@ def run_model_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_model(world_size): def test_model(world_size):
run_func = partial(run_model_dist, world_size=world_size, port=free_port()) spawn(run_model_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
def run_pretrain_load_dist(rank, world_size, port): def run_pretrain_load_dist(rank, world_size, port):
...@@ -329,8 +324,7 @@ def run_pretrain_load_dist(rank, world_size, port): ...@@ -329,8 +324,7 @@ def run_pretrain_load_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_pretrain_load(world_size): def test_pretrain_load(world_size):
run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port()) spawn(run_pretrain_load_dist, world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment