Unverified commit 80eba05b authored by Frank Lee, committed by GitHub

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
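The refactor replaces the per-test boilerplate of functools.partial plus torch.multiprocessing.spawn with a single spawn helper from colossalai.testing. The helper's implementation is not part of this diff; judging only from the call sites below, a minimal sketch of the contract it has to satisfy (the port handling is an assumption) would be:

import torch.multiprocessing as mp
from functools import partial

from colossalai.utils import free_port


def spawn(func, nprocs=1, **kwargs):
    # sketch only: the real colossalai.testing.spawn may differ
    # launches func(rank, world_size, port, **kwargs) on nprocs processes,
    # picking a fresh port per call so reruns avoid "address already in use"
    wrapped = partial(func, world_size=nprocs, port=free_port(), **kwargs)
    mp.spawn(wrapped, nprocs=nprocs)

Call sites then shrink to spawn(run_dist, world_size), with any extra worker arguments forwarded as keywords, e.g. spawn(run_dist, world_size, use_ddp=use_ddp).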
import colossalai
import torch
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor import ColoTensorSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial
from tests.test_tensor.common_utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
class Conv1D(nn.Module):
@@ -69,8 +66,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_addmm_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':
......
import pytest
import torch
from torch.nn import functional as F
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal
def run_with_spec(spec_init_func):
@@ -39,8 +36,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_embedding_bag_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':
......
import pytest
import torch
from torch.nn import functional as F
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
def run_with_spec(spec_init_func, pg: ProcessGroup):
@@ -40,8 +37,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_embedding_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':
......
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal
def run_with_spec(spec_init_func, split_bias):
@@ -44,8 +41,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_linear_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':
......
import torch
import pytest
import colossalai
import torch.nn.functional as F
import torch.multiprocessing as mp
from functools import partial
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
from colossalai.utils import get_current_device
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern
def check_cross_entropy():
input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
with torch.no_grad():
input_ct.copy_(input_t)
target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
output = F.cross_entropy(input_t, target)
output_colo = F.cross_entropy(input_shard, target)
assert torch.allclose(output_colo, output)
output.backward()
output_colo.backward()
assert torch.allclose(input_t.grad, input_ct.grad)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_cross_entropy()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_loss_func(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_loss_func(1)
import pytest
import torch
import torch.nn.functional as F
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
def check_cross_entropy():
input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True)
with torch.no_grad():
input_ct.copy_(input_t)
target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device())
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
output = F.cross_entropy(input_t, target)
output_colo = F.cross_entropy(input_shard, target)
assert torch.allclose(output_colo, output)
output.backward()
output_colo.backward()
assert torch.allclose(input_t.grad, input_ct.grad)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_cross_entropy()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_loss_func(world_size):
spawn(run_dist, world_size)
if __name__ == '__main__':
test_loss_func(1)
import torch
import pytest
import colossalai
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
from functools import partial
from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec
from colossalai.utils import get_current_device
from torch.nn import Parameter
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
def _run_layer_norm():
@@ -66,8 +64,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_element_wise_ops(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
def run_dist2(rank, world_size, port):
@@ -79,8 +76,7 @@ def run_dist2(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1])
@rerun_if_address_is_in_use()
def test_ln(world_size):
run_func = partial(run_dist2, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist2, world_size)
def check_all():
......
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor, ShardSpec
from colossalai.tensor.distspec import DistPlacementPattern
from tests.test_tensor.common_utils import split_param_row_tp1d, split_param_col_tp1d, debug_print
def exam_view_core(pg):
# the case of replicated ColoTensors
x = torch.randn(4, 4).cuda()
x_colo = ColoTensor(x, ColoTensorSpec(pg))
y = x.view(2, -1, 2)
y_colo = x_colo.view(2, -1, 2)
assert torch.all(y == y_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
# the perfect case of col-sliced ColoTensors
split_param_col_tp1d(x_colo, pg)
z = x.view(torch.Size((2, 1, 2, -1)))
z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
if dist.get_rank() == 0:
z = z[:, :, :, 0:2]
else:
z = z[:, :, :, 2:]
assert torch.all(z == z_colo)
assert z_colo.dist_spec == x_colo.dist_spec
# the perfect case of row-sliced ColoTensors
split_param_row_tp1d(x_colo, pg)
z = x.view(torch.Size((-1, 2, 2)))
z_colo = x_colo.view(torch.Size((-1, 2, 2)))
if dist.get_rank() == 0:
z = z[0:2, :, :]
else:
z = z[2:, :, :]
assert torch.all(z == z_colo)
assert z_colo.dist_spec == x_colo.dist_spec
# the normal case of row-sliced ColoTensors
z = x.view(-1, 2, 2, 2)
z_colo = x_colo.view(-1, 2, 2, 2)
assert torch.all(z == z_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
def exam_view_autograd(pg):
x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
with torch.no_grad():
y.copy_(x)
y = ColoTensor(y, ColoTensorSpec(pg))
y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
xx = x.view(2, 2, -1)
yy_slice = y_slice.view(2, 2, -1)
yy = yy_slice.to_replicate()
grad = torch.randn(2, 2, 4, device=get_current_device())
xx.backward(grad)
yy.backward(grad)
assert torch.all(x.grad == y.grad)
def exam_view_errors(pg):
x = torch.randn(8, 2, device=get_current_device())
x = ColoTensor(x, ColoTensorSpec(pg))
split_param_row_tp1d(x, pg)
x.view('a', 'b', 'c')
x.view(8, -1)
x.view([-2, -2, -2])
x.view((-1, -1, -1))
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
exam_view_core(pg)
exam_view_autograd(pg)
# exam_view_errors(pg)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_view(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_view(2)
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec
from colossalai.tensor.distspec import DistPlacementPattern
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d
def exam_view_core(pg):
# the case of replicated ColoTensors
x = torch.randn(4, 4).cuda()
x_colo = ColoTensor(x, ColoTensorSpec(pg))
y = x.view(2, -1, 2)
y_colo = x_colo.view(2, -1, 2)
assert torch.all(y == y_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
# the perfect case of col-sliced ColoTensors
split_param_col_tp1d(x_colo, pg)
z = x.view(torch.Size((2, 1, 2, -1)))
z_colo = x_colo.view(torch.Size((2, 1, 2, -1)))
if dist.get_rank() == 0:
z = z[:, :, :, 0:2]
else:
z = z[:, :, :, 2:]
assert torch.all(z == z_colo)
assert z_colo.dist_spec == x_colo.dist_spec
# the perfect case of row-sliced ColoTensors
split_param_row_tp1d(x_colo, pg)
z = x.view(torch.Size((-1, 2, 2)))
z_colo = x_colo.view(torch.Size((-1, 2, 2)))
if dist.get_rank() == 0:
z = z[0:2, :, :]
else:
z = z[2:, :, :]
assert torch.all(z == z_colo)
assert z_colo.dist_spec == x_colo.dist_spec
# the normal case of row-sliced ColoTensors
z = x.view(-1, 2, 2, 2)
z_colo = x_colo.view(-1, 2, 2, 2)
assert torch.all(z == z_colo)
assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE
def exam_view_autograd(pg):
x = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
y = torch.randn(8, 2, device=get_current_device(), requires_grad=True)
with torch.no_grad():
y.copy_(x)
y = ColoTensor(y, ColoTensorSpec(pg))
y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
xx = x.view(2, 2, -1)
yy_slice = y_slice.view(2, 2, -1)
yy = yy_slice.to_replicate()
grad = torch.randn(2, 2, 4, device=get_current_device())
xx.backward(grad)
yy.backward(grad)
assert torch.all(x.grad == y.grad)
def exam_view_errors(pg):
x = torch.randn(8, 2, device=get_current_device())
x = ColoTensor(x, ColoTensorSpec(pg))
split_param_row_tp1d(x, pg)
x.view('a', 'b', 'c')
x.view(8, -1)
x.view([-2, -2, -2])
x.view((-1, -1, -1))
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
exam_view_core(pg)
exam_view_autograd(pg)
# exam_view_errors(pg)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_view(world_size):
spawn(run_dist, world_size)
if __name__ == '__main__':
test_view(2)
@@ -2,7 +2,7 @@ import math
import torch
from colossalai.testing import parameterize
from colossalai.testing import clear_cache_before_run, parameterize
def torch_adam_update(
@@ -46,6 +46,7 @@ def assertTrue(condition, msg):
assert condition, msg
@clear_cache_before_run()
@parameterize('adamw', [True, False])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
......
import torch
import torch.nn as nn
from torch.optim.adam import Adam
from torch.optim import AdamW
from torch.optim.adam import Adam
from colossalai.nn.optimizer.fused_adam import FusedAdam
from colossalai.testing import parameterize
from colossalai.testing import clear_cache_before_run, parameterize
class FC(nn.Module):
@@ -17,6 +17,7 @@ class FC(nn.Module):
return self.fc(x)
@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
......
@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
from numpy import dtype
from colossalai.testing import parameterize
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils import multi_tensor_applier
@@ -41,6 +41,7 @@ def torch_adam_update(
param.addcdiv_(exp_avg, denom, value=-step_size)
@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
......
@@ -4,11 +4,12 @@ from torch.optim import AdamW
from torch.optim.adam import Adam
from colossalai.nn.optimizer.hybrid_adam import HybridAdam
from colossalai.testing import parameterize
from colossalai.testing import clear_cache_before_run, parameterize
RE = 3
@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('device', ['cpu', 'cuda:0'])
@parameterize('p_dtype', [torch.float])
......
import pytest
import torch
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize
from tests.components_to_test.registry import non_distributed_component_funcs
def move_some_params_to_cuda(model, torch_model):
@@ -16,9 +18,10 @@ def check_params_equal(model, torch_model):
assert torch.allclose(p, torch_p, atol=1e-3), f'diff: {torch.abs(p - torch_p)}'
@pytest.mark.parametrize('nvme_offload_fraction', [0.0, 0.5, 1.0])
@pytest.mark.parametrize('nvme_offload_dir', ['./offload', None])
@pytest.mark.parametrize('adam_cls', [CPUAdam, HybridAdam])
@clear_cache_before_run()
@parameterize('nvme_offload_fraction', [0.0, 0.5, 1.0])
@parameterize('nvme_offload_dir', ['./offload', None])
@parameterize('adam_cls', [CPUAdam, HybridAdam])
def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls):
get_components_func = non_distributed_component_funcs.get_callable('simple_net')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
......
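The optimizer tests above also gain a clear_cache_before_run() decorator from colossalai.testing. Its body is not part of this diff; a plausible sketch, assuming it simply empties the CUDA allocator cache before the wrapped test executes, is:

import functools

import torch


def clear_cache_before_run():
    # sketch only: the real colossalai.testing implementation may differ
    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()    # release cached blocks so each parametrized run starts clean
            return func(*args, **kwargs)

        return wrapper

    return decorator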
@@ -6,13 +6,14 @@ import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
from colossalai import launch
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.pipeline_process_group import ppg
from torch import nn
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from torch.optim import SGD, Adam, Optimizer, RMSprop
from colossalai import launch
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.pipeline_process_group import ppg
rpc_is_initialized = _is_current_rpc_agent_set
@@ -20,7 +21,9 @@ def color_debug(text, prefix=' ', color='blue'):
color = color.upper()
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
class MLP(nn.Module):
def __init__(self, dim: int, layers: int):
super().__init__()
self.layers = torch.nn.ModuleList()
@@ -32,8 +35,10 @@ class MLP(nn.Module):
for layer in self.layers:
x = layer(x)
return x.sum()
class DAG_MLP(nn.Module):
def __init__(self, dim: int, layers: int):
super().__init__()
self.layers = torch.nn.ModuleList()
@@ -48,6 +53,7 @@ class DAG_MLP(nn.Module):
y = self.dag_layer(y)
return x.sum(), y.sum()
class RpcTestModel(nn.Module):
def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None:
......
import torch
import pytest
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
from functools import partial
from torch import nn
import pytest
import torch
import torch.distributed.rpc as rpc
from rpc_test_utils import DAG_MLP, MLP
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from colossalai import launch
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.middleware.adaptor import get_fx_topology
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
from colossalai.pipeline.middleware.adaptor import get_fx_topology
from rpc_test_utils import MLP, DAG_MLP
from functools import partial
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# global variables for model creation
batch_size = 16
dim = 10
rpc_is_initialized = _is_current_rpc_agent_set
def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
model.eval()
tracer = ColoTracer()
@@ -34,13 +34,15 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
for submodule in split_submodules:
if isinstance(submodule, torch.fx.GraphModule):
setattr(submodule, '_topo', topo)
return split_submodules[pp_rank+1]
return split_submodules[pp_rank + 1]
def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
torch.manual_seed(1024)
partition = create_partition_module(pp_rank, stage_num, model, data_kwargs)
return partition
def run_master(model_cls, world_size, forward_only):
torch.manual_seed(100)
@@ -50,23 +52,27 @@ def run_master(model_cls, world_size, forward_only):
chunk = 1
num_microbatches = 8
use_checkpoint = 'store_true'    # truthy string, so checkpointing is enabled when passed to the engine below
if model_cls == MLP:
def data_gen():
x = torch.zeros((batch_size, dim))
kwargs = dict(x=x)
return kwargs
model = model_cls(dim, stage_num * 3)
if forward_only:
labels = None
else:
labels = 1
elif model_cls == DAG_MLP:
def data_gen():
x = torch.zeros((batch_size, dim))
y = torch.zeros((batch_size, dim))
kwargs = dict(x=x, y=y)
return kwargs
model = model_cls(dim, stage_num * 3)
if forward_only:
labels = None
@@ -74,15 +80,17 @@ def run_master(model_cls, world_size, forward_only):
labels = 1
else:
pass
data_kwargs = data_gen()
engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs),
stage_num=stage_num,
num_microbatches=num_microbatches,
device=device,
chunk=chunk,
checkpoint=use_checkpoint,)
engine = OneFOneBPipelineEngine(
partition_fn=partial(partition, model, data_kwargs),
stage_num=stage_num,
num_microbatches=num_microbatches,
device=device,
chunk=chunk,
checkpoint=use_checkpoint,
)
if not forward_only:
engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3)
@@ -90,13 +98,14 @@ def run_master(model_cls, world_size, forward_only):
input_x = torch.randn((batch_size, dim), device=device)
input_y = torch.randn((batch_size, dim), device=device)
logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only)
def run_worker(rank, model_cls, world_size, forward_only, master_func):
def run_worker(rank, world_size, port, model_cls, forward_only, master_func):
master_addr = 'localhost'
master_port = 29020
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = str(master_port)
disable_existing_loggers()
launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False)
@@ -113,7 +122,8 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func):
# barrier here
if rpc_is_initialized():
rpc.shutdown()
@pytest.mark.skip("skip due to CI torch version 1.11")
@parameterize('model_cls', [MLP, DAG_MLP])
@parameterize('forward_only', [True, False])
@@ -122,7 +132,14 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func):
def test_pp_middleware_fwd(model_cls, forward_only):
world_size = 4
master_func = run_master
mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size)
spawn(
run_worker,
world_size,
model_cls=model_cls,
forward_only=forward_only,
master_func=master_func,
)
if __name__ == "__main__":
test_pp_middleware_fwd()
\ No newline at end of file
test_pp_middleware_fwd()
import torch
import torch.multiprocessing as mp
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.testing import rerun_on_exception
from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn
NUM_CHUNKS = 1
PIPELINE_SIZE = 2
@@ -27,7 +25,7 @@ class MLP(torch.nn.Module):
return x
def run_pipelinable(rank):
def run_pipelinable(rank, world_size, port):
pipelinable = PipelinableContext()
with pipelinable:
model = MLP()
@@ -50,9 +48,9 @@ def run_pipelinable(rank):
assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
@rerun_if_address_is_in_use()
def test_pipelinable():
mp.spawn(run_pipelinable, nprocs=1)
spawn(run_pipelinable, 1)
if __name__ == '__main__':
......
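The hunk above is the one place where the two retry decorators meet: rerun_if_address_is_in_use() directly replaces an explicit rerun_on_exception call. Judging from the removed line, it is plausibly just a thin wrapper (a sketch, not the verified implementation):

import torch.multiprocessing as mp

from colossalai.testing import rerun_on_exception


def rerun_if_address_is_in_use():
    # sketch: equivalent to the explicit form removed in the hunk above
    return rerun_on_exception(exception_type=mp.ProcessRaisedException,
                              pattern=".*Address already in use.*")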
import os
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import pytest
from rpc_test_utils import pg_parse_args, rpc_is_initialized
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from rpc_test_utils import pg_parse_args, rpc_is_initialized
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.testing import spawn
def run_worker(rank, args):
@@ -40,4 +39,4 @@ def run_worker(rank, args):
if __name__ == "__main__":
args = pg_parse_args()
world_size = args.world_size
mp.spawn(run_worker, args=(args,), nprocs=world_size)
\ No newline at end of file
spawn(run_worker, world_size, args=args)
import math
import pytest
import torch
import torch.distributed as dist
import pytest
import colossalai
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec
from functools import partial
from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
def run():
@@ -58,8 +57,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_dist_spec_mgr(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':
......
import torch
import pytest
from colossalai.tensor import ColoTensor
import torch
from numpy import allclose
import colossalai
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec
from colossalai.core import global_context as gpc
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec
from functools import partial
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec
from colossalai.testing import rerun_if_address_is_in_use, spawn
def _run_tensor_indexing():
@@ -152,8 +146,7 @@ def run_dist_tests(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_dist_cases(world_size):
run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist_tests, world_size)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -145,8 +141,7 @@ def run_dist(rank, world_size, port, use_ddp):
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_gpt(world_size, use_ddp):
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, use_ddp=use_ddp)
if __name__ == '__main__':
......
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import colossalai
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -313,8 +309,7 @@ def run_model_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_model(world_size):
run_func = partial(run_model_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_model_dist, world_size)
def run_pretrain_load_dist(rank, world_size, port):
@@ -329,8 +324,7 @@ def run_pretrain_load_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_pretrain_load(world_size):
run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_pretrain_load_dist, world_size)
if __name__ == '__main__':
......