Commit 060b917d (unverified), authored by Jiarui Fang, committed by GitHub
Browse files

[refactor] remove gpc dependency in colotensor's _ops (#1189)

parent abf6a262
...@@ -7,12 +7,12 @@ import torch.multiprocessing as mp ...@@ -7,12 +7,12 @@ import torch.multiprocessing as mp
from torch.distributed.distributed_c10d import _get_default_group from torch.distributed.distributed_c10d import _get_default_group
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.tensor import DistSpecManager, distspec from colossalai.tensor import DistSpecManager, distspec, ProcessGroup
from functools import partial from functools import partial
def run(): def run():
group = _get_default_group() group = ProcessGroup(tp_degree=dist.get_world_size())
rank = dist.get_rank() rank = dist.get_rank()
size = dist.get_world_size() size = dist.get_world_size()
depth = int(math.sqrt(size)) depth = int(math.sqrt(size))
...@@ -34,7 +34,7 @@ def run(): ...@@ -34,7 +34,7 @@ def run():
def check_mem(): def check_mem():
group = _get_default_group() group = ProcessGroup(tp_degree=dist.get_world_size())
size = dist.get_world_size() size = dist.get_world_size()
assert torch.cuda.memory_allocated() == 0 assert torch.cuda.memory_allocated() == 0
x = torch.rand(32, 32).cuda() x = torch.rand(32, 32).cuda()
......
import torch import torch
from colossalai.context.parallel_mode import ParallelMode from colossalai.tensor import distspec, ColoParameter
from colossalai.tensor import ColoTensor, distspec, ColoParameter
from torch.nn import functional as F from torch.nn import functional as F
from functools import partial from functools import partial
...@@ -10,23 +9,21 @@ import torch ...@@ -10,23 +9,21 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.core import global_context as gpc from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
from _utils import tensor_equal, tensor_shard_equal from _utils import tensor_equal, tensor_shard_equal
def init_1d_col(weight): def init_1d_col(weight, pg: ProcessGroup):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
def run_with_spec(spec_init_func): def run_with_spec(spec_init_func):
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
model = torch.nn.EmbeddingBag(10, 4).cuda() model = torch.nn.EmbeddingBag(10, 4).cuda()
weight = ColoParameter(model.weight.clone()) weight = ColoParameter(model.weight.clone())
spec_init_func(weight) spec_init_func(weight, pg)
inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda() inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
offsets = torch.tensor([0, 4]).cuda() offsets = torch.tensor([0, 4]).cuda()
out = model(inputs, offsets=offsets) out = model(inputs, offsets=offsets)
...@@ -35,7 +32,7 @@ def run_with_spec(spec_init_func): ...@@ -35,7 +32,7 @@ def run_with_spec(spec_init_func):
grad = torch.rand_like(out) grad = torch.rand_like(out)
out.backward(grad) out.backward(grad)
colo_out.backward(grad) colo_out.backward(grad)
assert tensor_shard_equal(model.weight.grad, weight.grad) assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
......
import torch import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor, distspec from colossalai.tensor import ColoTensor, distspec
from torch.nn import functional as F from torch.nn import functional as F
from functools import partial from functools import partial
...@@ -11,30 +10,26 @@ import torch.multiprocessing as mp ...@@ -11,30 +10,26 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight): def init_1d_row(weight, pg: ProcessGroup):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
def init_1d_col(weight): def init_1d_col(weight, pg: ProcessGroup):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
def run_with_spec(spec_init_func): def run_with_spec(spec_init_func, pg: ProcessGroup):
model = torch.nn.Embedding(12, 32).cuda() model = torch.nn.Embedding(12, 32).cuda()
weight = ColoTensor(torch.nn.Parameter(model.weight.detach())) weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
spec_init_func(weight) spec_init_func(weight, pg)
x = torch.tensor((0, 3, 6, 9)).cuda() x = torch.tensor((0, 3, 6, 9)).cuda()
out = model(x) out = model(x)
colo_out = F.embedding(x, weight) colo_out = F.embedding(x, weight)
...@@ -42,14 +37,16 @@ def run_with_spec(spec_init_func): ...@@ -42,14 +37,16 @@ def run_with_spec(spec_init_func):
grad = torch.rand_like(out) grad = torch.rand_like(out)
out.backward(grad) out.backward(grad)
colo_out.backward(grad) colo_out.backward(grad)
assert tensor_shard_equal(model.weight.grad, weight.grad) # compare grad inside a TP group
assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) # config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_with_spec(init_1d_row) pg = ProcessGroup(tp_degree=world_size)
run_with_spec(init_1d_col) run_with_spec(init_1d_row, pg)
run_with_spec(init_1d_col, pg)
@pytest.mark.dist @pytest.mark.dist
......
import pytest import pytest
import colossalai import colossalai
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
from colossalai.core import global_context as gpc
from functools import partial from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
import torch
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
def init_1d_row_spec(model): def init_1d_row_spec(model, pg: ProcessGroup):
spec = TensorSpec( tensor_spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n: if 'weight' in n and 'ln' not in n:
p.set_tensor_spec(spec) p.set_tensor_spec(tensor_spec)
def init_1d_col_spec(model): def init_1d_col_spec(model, pg: ProcessGroup):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n): if 'ln' not in n and ('weight' in n or 'bias' in n):
p.set_tensor_spec(spec) p.set_tensor_spec(spec)
def check_param_equal(model, torch_model): def check_param_equal(model, torch_model, pg: ProcessGroup):
for p, torch_p in zip(model.parameters(), torch_model.parameters()): for p, torch_p in zip(model.parameters(), torch_model.parameters()):
assert tensor_shard_equal(torch_p, p) assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1"
assert pg.tp_world_size() is not None
assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
def check_grad_equal(model, torch_model): def check_grad_equal(model, torch_model, pg: ProcessGroup):
for p, torch_p in zip(model.parameters(), torch_model.parameters()): for p, torch_p in zip(model.parameters(), torch_model.parameters()):
assert tensor_shard_equal(torch_p.grad, p.grad) assert tensor_shard_equal(torch_p.grad, p.grad, pg.tp_local_rank(), pg.tp_world_size())
def run_gpt(init_spec_func, use_ddp): def run_gpt(init_spec_func, use_ddp):
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))
get_components_func = non_distributed_component_funcs.get_callable('gpt2') get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
...@@ -54,21 +57,25 @@ def run_gpt(init_spec_func, use_ddp): ...@@ -54,21 +57,25 @@ def run_gpt(init_spec_func, use_ddp):
model = model.cuda() model = model.cuda()
torch_model = model_builder().cuda() torch_model = model_builder().cuda()
if use_ddp: if use_ddp:
model = ColoDDP(model) # torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg)
# torch.distributed.barrier()
torch_model = DDP(torch_model, torch_model = DDP(torch_model,
device_ids=[gpc.get_global_rank()], device_ids=[gpc.get_global_rank()],
process_group=gpc.get_group(ParallelMode.DATA)) process_group=gpc.get_group(ParallelMode.DATA))
model = ColoDDP(model, process_group=pg)
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p) torch_p.data.copy_(p)
init_spec_func(model) init_spec_func(model, pg)
check_param_equal(model, torch_model) check_param_equal(model, torch_model, pg)
model.train() model.train()
torch_model.train() torch_model.train()
set_seed(gpc.get_local_rank(ParallelMode.DATA)) set_seed(pg.tp_local_rank())
for i, (input_ids, attn_mask) in enumerate(train_dataloader): for i, (input_ids, attn_mask) in enumerate(train_dataloader):
logits = model(input_ids, attn_mask) logits = model(input_ids, attn_mask)
torch_logits = torch_model(input_ids, attn_mask) torch_logits = torch_model(input_ids, attn_mask)
assert tensor_equal(torch_logits, logits) assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"
loss = criterion(logits, input_ids) loss = criterion(logits, input_ids)
torch_loss = criterion(torch_logits, input_ids) torch_loss = criterion(torch_logits, input_ids)
if use_ddp: if use_ddp:
...@@ -76,7 +83,7 @@ def run_gpt(init_spec_func, use_ddp): ...@@ -76,7 +83,7 @@ def run_gpt(init_spec_func, use_ddp):
else: else:
loss.backward() loss.backward()
torch_loss.backward() torch_loss.backward()
check_grad_equal(model, torch_model) check_grad_equal(model, torch_model, pg)
if i > 0: if i > 0:
break break
...@@ -87,11 +94,12 @@ def run_dist(rank, world_size, port, use_ddp): ...@@ -87,11 +94,12 @@ def run_dist(rank, world_size, port, use_ddp):
tp_world_size = world_size // 2 if use_ddp else world_size tp_world_size = world_size // 2 if use_ddp else world_size
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt(init_1d_row_spec, use_ddp) # run_gpt(init_1d_row_spec, use_ddp)
run_gpt(init_1d_col_spec, use_ddp) run_gpt(init_1d_col_spec, use_ddp)
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.skip("under development")
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True]) @pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
......
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ComputePattern, ComputeSpec
from functools import partial
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.nn.optimizer import ColoOptimizer
import colossalai
import torch
import torch.multiprocessing as mp
import pytest
class Net(torch.nn.Module):
    """Toy model: an embedding lookup followed by a linear projection.

    The embedding lookup is always fed a CPU copy of the input; the embedded
    activations are then moved back to the input's original device before
    being passed to the projection.
    """

    def __init__(self):
        super().__init__()
        # 20-entry table of 4-dim embeddings, projected to 8 output features.
        self.embed = torch.nn.Embedding(20, 4)
        self.proj = torch.nn.Linear(4, 8)

    def forward(self, x):
        """Embed ``x`` via a CPU copy, then project on ``x``'s original device."""
        input_device = x.device
        embedded = self.embed(x.to('cpu'))
        return self.proj(embedded.to(input_device))
def run_hybrid_device(use_ddp, mode):
    """Run one forward/backward/optimizer step with the embedding moved to CPU.

    Args:
        use_ddp: when true, the model is wrapped in ``ColoDDP`` and the inner
            module is reached through ``.module``.
        mode: forwarded unchanged to ``init_colo_module`` as ``mode``
            (the visible callers pass ``'col'`` or ``'row'``).
    """
    with ColoInitContext(device=get_current_device()):
        model = Net()

    # `inner` always refers to the bare Net, whether or not it is wrapped.
    if use_ddp:
        model = ColoDDP(model)
        inner = model.module
    else:
        inner = model

    print(f'embedding weight size: {inner.embed.weight.size()} | device: {inner.embed.weight.device}')

    # NOTE(review): this still goes through the global context (gpc) and does
    # not pass an explicit process group to init_colo_module / ColoDDP --
    # confirm whether that is intended.
    tp_compute_spec = ComputeSpec(ComputePattern.TP1D)
    init_colo_module(model, tp_compute_spec, recursive=True, mode=mode)

    # Move the embedding to CPU and point its dist spec at the CPU group of
    # the 1D tensor-parallel mode (presumably a gloo group -- TODO confirm).
    inner.embed.to('cpu')
    cpu_tp_group = gpc.get_cpu_group(ParallelMode.PARALLEL_1D)
    inner.embed.weight.spec.dist_spec.process_group = cpu_tp_group

    print(f'embedding weight size: {inner.embed.weight.size()} | new device: {inner.embed.weight.device}')

    optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)

    # 16 random token ids in [0, 20), matching the embedding table size in Net.
    token_ids = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
    output = model(token_ids)
    output.sum().backward()
    optimizer.step()
def run_dist(rank, world_size, port, use_ddp, mode):
    """Per-process entry point: launch colossalai, then run the hybrid-device check.

    Args:
        rank: rank of this process, forwarded to ``colossalai.launch``.
        world_size: total number of processes.
        port: port forwarded to ``colossalai.launch`` (host is ``localhost``).
        use_ddp: forwarded to ``run_hybrid_device``; also halves the
            tensor-parallel size.
        mode: forwarded to ``run_hybrid_device``.

    Returns immediately, without launching anything, when ``use_ddp`` is set
    and ``world_size`` is 1.
    """
    if use_ddp and world_size == 1:
        return

    # With DDP the 1D tensor-parallel size is half the world size;
    # otherwise it is the whole world size.
    if use_ddp:
        tp_world_size = world_size // 2
    else:
        tp_world_size = world_size

    launch_config = {'parallel': {'tensor': {'mode': "1d", 'size': tp_world_size}}}
    colossalai.launch(
        config=launch_config,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    run_hybrid_device(use_ddp, mode)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@pytest.mark.parametrize('mode', ['col', 'row'])
@rerun_if_address_is_in_use()
def _test_hybrid_device(world_size, use_ddp, mode):
    """Spawn ``world_size`` processes that each execute ``run_dist``.

    Simulates an embedding on CPU (DP+TP) feeding an nn layer on GPU (DP+TP).

    NOTE(review): the leading underscore presumably keeps pytest from
    collecting this function -- confirm that it is meant to stay disabled.
    """
    # mp.spawn supplies the process index as the first positional argument,
    # which becomes `rank` in run_dist; everything else is bound here.
    worker = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, mode=mode)
    mp.spawn(worker, nprocs=world_size)
if __name__ == '__main__':
    # Direct invocation runs a single fixed configuration:
    # world_size=4, use_ddp=True, mode='row'.
    _test_hybrid_device(4, True, 'row')
...@@ -12,32 +12,29 @@ import torch.nn.functional as F ...@@ -12,32 +12,29 @@ import torch.nn.functional as F
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight, bias): def init_1d_row(weight, bias, pg: ProcessGroup):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
def init_1d_col(weight, bias): def init_1d_col(weight, bias, pg: ProcessGroup):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
bias.set_tensor_spec(spec) bias.set_tensor_spec(spec)
def run_with_spec(spec_init_func): def run_with_spec(spec_init_func):
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
model = torch.nn.Linear(4, 8).cuda() model = torch.nn.Linear(4, 8).cuda()
weight = ColoTensor(torch.nn.Parameter(model.weight.detach())) weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
bias = ColoTensor(torch.nn.Parameter(model.bias.detach())) bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
spec_init_func(weight, bias) spec_init_func(weight, bias, pg)
x = torch.rand(2, 4).cuda() x = torch.rand(2, 4).cuda()
out = model(x) out = model(x)
colo_out = F.linear(x, weight, bias) colo_out = F.linear(x, weight, bias)
...@@ -46,8 +43,8 @@ def run_with_spec(spec_init_func): ...@@ -46,8 +43,8 @@ def run_with_spec(spec_init_func):
grad = torch.rand_like(out) grad = torch.rand_like(out)
out.backward(grad) out.backward(grad)
colo_out.backward(grad) colo_out.backward(grad)
assert tensor_shard_equal(model.weight.grad, weight.grad) assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
assert tensor_shard_equal(model.bias.grad, bias.grad) assert tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
......
from colossalai.tensor.colo_parameter import ColoParameter
from tests.components_to_test.registry import non_distributed_component_funcs
import colossalai
import pytest import pytest
from functools import partial
from _utils import tensor_shard_equal, set_seed
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.tensor.colo_parameter import ColoParameter
import colossalai
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
...@@ -12,34 +14,30 @@ from colossalai.utils.model.colo_init_context import ColoInitContext ...@@ -12,34 +14,30 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \ from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
from colossalai.nn.optimizer import ColoOptimizer from colossalai.nn.optimizer import ColoOptimizer
from functools import partial
from _utils import tensor_shard_equal, set_seed from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_linear(weight, pg: ProcessGroup): def init_1d_row_linear(weight, pg: ProcessGroup):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]), spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
def init_1d_col_linear(weight, pg): def init_1d_col_linear(weight, pg):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]), spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
def init_1d_row_embedding(weight, pg): def init_1d_row_embedding(weight, pg):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]), spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
def init_1d_col_embedding(weight, pg): def init_1d_col_embedding(weight, pg):
spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]), spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
...@@ -142,7 +140,7 @@ def run_1d_hybrid_tp(model_name): ...@@ -142,7 +140,7 @@ def run_1d_hybrid_tp(model_name):
with torch.no_grad(): with torch.no_grad():
# check param # check param
for p, torch_p in zip(model.parameters(), model_torch.parameters()): for p, torch_p in zip(model.parameters(), model_torch.parameters()):
assert tensor_shard_equal(torch_p, p) assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
if i > 5: if i > 5:
break break
......
...@@ -13,12 +13,10 @@ import colossalai ...@@ -13,12 +13,10 @@ import colossalai
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.context.parallel_mode import ParallelMode from colossalai.tensor import distspec, ProcessGroup
from colossalai.tensor import distspec
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
...@@ -26,7 +24,9 @@ from tests.components_to_test.registry import non_distributed_component_funcs ...@@ -26,7 +24,9 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def run_model_with_spec(mode, model_name): def run_model_with_spec(mode, model_name):
get_components_func = non_distributed_component_funcs.get_callable(model_name) get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
rank = pg.rank()
set_seed(1) set_seed(1)
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
...@@ -40,28 +40,28 @@ def run_model_with_spec(mode, model_name): ...@@ -40,28 +40,28 @@ def run_model_with_spec(mode, model_name):
for p1, p2 in zip(model.parameters(), model_seq.parameters()): for p1, p2 in zip(model.parameters(), model_seq.parameters()):
p2.data.copy_(p1.data) p2.data.copy_(p1.data)
parallel_action = ComputeSpec(ComputePattern.TP1D) compute_spec = ComputeSpec(ComputePattern.TP1D)
# Not all layers in Bert can be mod by 4. # Not all layers in Bert can be mod by 4.
# e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2. # e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
if 'bert' == model_name: if 'bert' == model_name:
if 'col' == mode: if 'col' == mode:
init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode=mode) init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode)
init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode) init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
init_colo_module(model.classifier, parallel_action, recursive=True, mode='row') init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row')
elif 'row' == mode: elif 'row' == mode:
init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode='col') init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col')
init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode) init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
init_colo_module(model.classifier, parallel_action, recursive=True, mode=mode) init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode)
elif 'simple_net' == model_name: elif 'simple_net' == model_name:
init_colo_module(model, parallel_action, recursive=True, mode=mode) init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
model = model.cuda() model = model.cuda()
for i, (data, label) in enumerate(train_dataloader): for i, (data, label) in enumerate(train_dataloader):
data = data.to(get_current_device()) data = data.to(get_current_device())
label = label.to(get_current_device()) label = label.to(get_current_device())
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D)) torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D)) torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
if criterion: if criterion:
output = model(data) output = model(data)
...@@ -113,9 +113,10 @@ def run_linear_with_spec(mode): ...@@ -113,9 +113,10 @@ def run_linear_with_spec(mode):
model = torch.nn.Linear(4, 8) model = torch.nn.Linear(4, 8)
model_handy = copy(model) model_handy = copy(model)
world_size = torch.distributed.get_world_size()
parallel_action = ComputeSpec(ComputePattern.TP1D) pg = ProcessGroup(tp_degree=world_size)
init_colo_module(model, parallel_action, recursive=True, mode=mode) compute_spec = ComputeSpec(ComputePattern.TP1D)
init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
x = torch.rand(2, 4).cuda() x = torch.rand(2, 4).cuda()
out = model(x) out = model(x)
...@@ -124,8 +125,8 @@ def run_linear_with_spec(mode): ...@@ -124,8 +125,8 @@ def run_linear_with_spec(mode):
grad = torch.rand_like(out) grad = torch.rand_like(out)
out.backward(grad) out.backward(grad)
colo_out.backward(grad) colo_out.backward(grad)
assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad) assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad) assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
def run_check_shared_param(): def run_check_shared_param():
...@@ -136,6 +137,10 @@ def run_check_shared_param(): ...@@ -136,6 +137,10 @@ def run_check_shared_param():
num_layer = 2 num_layer = 2
vocab_size = 24 vocab_size = 24
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
rank = pg.rank()
config = BertConfig(vocab_size=vocab_size, config = BertConfig(vocab_size=vocab_size,
hidden_size=hidden_dim, hidden_size=hidden_dim,
intermediate_size=hidden_dim * 4, intermediate_size=hidden_dim * 4,
...@@ -148,18 +153,16 @@ def run_check_shared_param(): ...@@ -148,18 +153,16 @@ def run_check_shared_param():
model = BertForMaskedLM(config) model = BertForMaskedLM(config)
model = model.cuda() model = model.cuda()
parallel_action = ComputeSpec(ComputePattern.TP1D) compute_spec = ComputeSpec(ComputePattern.TP1D)
# model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2 assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
# They are all Linear, so both row is allowed. This should pass check. # They are all Linear, so both row is allowed. This should pass check.
init_colo_module(model, parallel_action, recursive=True, mode='row') init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
# This should be detected by check because you can not set weight as row while set bias as col. # This should be detected by check because you can not set weight as row while set bias as col.
col_spec = TensorSpec( col_spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
model.cls.predictions.bias.set_tensor_spec(col_spec) model.cls.predictions.bias.set_tensor_spec(col_spec)
try: try:
check_colo_module(model.cls.predictions.decoder, recursive=False) check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False)
except Exception as e: except Exception as e:
assert 'incorrectly sharded' in str(e) assert 'incorrectly sharded' in str(e)
......
...@@ -4,10 +4,9 @@ import colossalai ...@@ -4,10 +4,9 @@ import colossalai
import torch.nn.functional as F import torch.nn.functional as F
import torch.multiprocessing as mp import torch.multiprocessing as mp
from functools import partial from functools import partial
from colossalai.tensor import ColoTensor, ColoParameter from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from torch.nn import Parameter from torch.nn import Parameter
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.tensor import distspec, TensorSpec from colossalai.tensor import distspec, TensorSpec
...@@ -43,9 +42,10 @@ def check_spec_eq(tensor, other): ...@@ -43,9 +42,10 @@ def check_spec_eq(tensor, other):
def check_element_wise_ops(): def check_element_wise_ops():
pg = _get_default_group() world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
t = torch.rand(2, 2) t = torch.rand(2, 2)
x = ColoTensor(t, spec=TensorSpec(distspec.shard(pg, [0], [pg.size()]))) x = ColoTensor(t, spec=TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()])))
check_spec_eq(x, x.cuda()) check_spec_eq(x, x.cuda())
assert torch.equal(x.cuda(), t.cuda()) assert torch.equal(x.cuda(), t.cuda())
check_spec_eq(x, torch.abs(x)) check_spec_eq(x, torch.abs(x))
......
...@@ -11,7 +11,6 @@ import torch.multiprocessing as mp ...@@ -11,7 +11,6 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
from colossalai.context import ParallelMode
from functools import partial from functools import partial
...@@ -55,11 +54,9 @@ def test_operand(): ...@@ -55,11 +54,9 @@ def test_operand():
def _run_view(world_size): def _run_view(world_size):
t_ref = torch.randn(4, 5) t_ref = torch.randn(4, 5)
rank = gpc.get_global_rank() rank = gpc.get_global_rank()
pg = ProcessGroup(rank, list(range(world_size))) pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"
t = ColoTensor.from_torch_tensor( t = ColoTensor.from_torch_tensor(
t_ref, t_ref, TensorSpec(distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])))
TensorSpec(distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])))
assert t.size_global()[0] == 4 * world_size assert t.size_global()[0] == 4 * world_size
assert t.size_global(1) == 5 assert t.size_global(1) == 5
...@@ -77,12 +74,12 @@ def _run_tensor_shard_init(world_size): ...@@ -77,12 +74,12 @@ def _run_tensor_shard_init(world_size):
t_ref = torch.randn(4, 5) t_ref = torch.randn(4, 5)
rank = gpc.get_global_rank() rank = gpc.get_global_rank()
pg = ProcessGroup(rank, list(range(world_size))) pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
shard_spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()]) shard_spec = distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])
tensor_spec = TensorSpec(shard_spec) tensor_spec = TensorSpec(shard_spec)
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate())) t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
assert t.shape == torch.Size((4 * world_size, 5)) assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"
def _run_tensor_replicated_init(world_size): def _run_tensor_replicated_init(world_size):
...@@ -92,11 +89,19 @@ def _run_tensor_replicated_init(world_size): ...@@ -92,11 +89,19 @@ def _run_tensor_replicated_init(world_size):
assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}" assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"
def _run_process_group(world_size):
pg1 = ProcessGroup()
pg2 = ProcessGroup()
assert pg1 == pg2
def run_dist_tests(rank, world_size, port): def run_dist_tests(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
_run_tensor_shard_init(world_size) _run_tensor_shard_init(world_size)
_run_tensor_replicated_init(world_size) _run_tensor_replicated_init(world_size)
_run_view(world_size) _run_view(world_size)
_run_process_group(world_size)
@pytest.mark.dist @pytest.mark.dist
......
...@@ -2,13 +2,11 @@ import pytest ...@@ -2,13 +2,11 @@ import pytest
import colossalai import colossalai
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini import ChunkManager from colossalai.gemini import ChunkManager
from colossalai.core import global_context as gpc
from functools import partial from functools import partial
from _utils import tensor_equal, set_seed, tensor_shard_equal from _utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
...@@ -19,20 +17,22 @@ from colossalai.zero import ZeroOptimizer ...@@ -19,20 +17,22 @@ from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
def check_param_equal(model, torch_model): def check_param_equal(model, torch_model, pg: ProcessGroup):
for p, torch_p in zip(model.parameters(), torch_model.parameters()): for p, torch_p in zip(model.parameters(), torch_model.parameters()):
if p.storage().size() > 0: if p.storage().size() > 0:
assert p.dtype == torch.half assert p.dtype == torch.half
assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}' assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
pg.tp_world_size()), f'{torch_p} vs {p}'
def check_grad_equal(model, torch_model): def check_grad_equal(model, torch_model, pg: ProcessGroup):
for p, torch_p in zip(model.parameters(), torch_model.parameters()): for p, torch_p in zip(model.parameters(), torch_model.parameters()):
if p.grad is not None: if p.grad is not None:
assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad) assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
pg.tp_local_rank(), pg.tp_world_size())
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
...@@ -44,20 +44,16 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask): ...@@ -44,20 +44,16 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
return logits return logits
def init_1d_row_spec(model): def init_1d_row_spec(model, pg: ProcessGroup):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n: if 'weight' in n and 'ln' not in n:
p.set_tensor_spec(spec) p.set_tensor_spec(spec)
def init_1d_col_spec(model): def init_1d_col_spec(model, pg: ProcessGroup):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n): if 'ln' not in n and ('weight' in n or 'bias' in n):
...@@ -79,44 +75,51 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None): ...@@ -79,44 +75,51 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p) torch_p.data.copy_(p)
world_size = torch.distributed.get_world_size()
# world size, dp = 2, tp =2, construct a hybrid parallelism.
if world_size == 4:
pg = ProcessGroup(tp_degree=2)
else:
pg = ProcessGroup(tp_degree=world_size)
if tp_init_spec_func: if tp_init_spec_func:
tp_init_spec_func(model) tp_init_spec_func(model, pg)
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
chunk_manager = ChunkManager(chunk_size, chunk_manager = ChunkManager(chunk_size,
enable_distributed_storage=use_zero, enable_distributed_storage=use_zero,
init_device=GeminiManager.get_default_device(placement_policy)) init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager) gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager) model = ZeroDDP(model, gemini_manager, pg)
optim = HybridAdam(model.parameters(), lr=1e-3) optim = HybridAdam(model.parameters(), lr=1e-3)
optim = ZeroOptimizer(optim, model, initial_scale=32) optim = ZeroOptimizer(optim, model, initial_scale=32)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32) amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA)) torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
print(chunk_manager) # print(chunk_manager)
check_param_equal(model, torch_model) check_param_equal(model, torch_model, pg)
model.train() model.train()
torch_model.train() torch_model.train()
set_seed(gpc.get_local_rank(ParallelMode.DATA)) set_seed(pg.dp_local_rank())
for i, (input_ids, attn_mask) in enumerate(train_dataloader): for i, (input_ids, attn_mask) in enumerate(train_dataloader):
if i > 2: if i > 2:
break break
logits = run_fwd_bwd(model, criterion, optim, input_ids, attn_mask) logits = run_fwd_bwd(model, criterion, optim, input_ids, attn_mask)
torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask) torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
assert tensor_equal(logits, torch_logits) assert tensor_equal(logits, torch_logits)
check_grad_equal(model, torch_model) check_grad_equal(model, torch_model, pg)
optim.step() optim.step()
torch_optim.step() torch_optim.step()
check_param_equal(model, torch_model) check_param_equal(model, torch_model, pg)
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
config = {} config = {}
if world_size == 4:
config['parallel'] = {'tensor': {'mode': '1d', 'size': 2}}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if world_size == 4: if world_size == 4:
run_gpt(tp_init_spec_func=init_1d_col_spec) run_gpt(tp_init_spec_func=init_1d_col_spec)
...@@ -126,6 +129,7 @@ def run_dist(rank, world_size, port): ...@@ -126,6 +129,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.skip("under development")
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_gpt(world_size): def test_gpt(world_size):
......
import pytest import pytest
import colossalai import colossalai
import torch import torch
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from functools import partial from functools import partial
from tests.test_tensor._utils import set_seed from tests.test_tensor._utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs from tests.components_to_test.registry import non_distributed_component_funcs
...@@ -16,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext ...@@ -16,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2 from colossalai.zero.sharded_optim import ShardedOptimizerV2
from colossalai.tensor import ProcessGroup
def init_zero(model_builder, placement_policy): def init_zero(model_builder, placement_policy):
...@@ -64,7 +63,8 @@ def run_nested_model(placement_policy): ...@@ -64,7 +63,8 @@ def run_nested_model(placement_policy):
model.train() model.train()
model_copy.train() model_copy.train()
set_seed(gpc.get_local_rank(ParallelMode.DATA)) pg = ProcessGroup()
set_seed(pg.dp_local_rank())
data_iter = iter(train_dataloader) data_iter = iter(train_dataloader)
data, label = map(lambda x: x.cuda(), next(data_iter)) data, label = map(lambda x: x.cuda(), next(data_iter))
......
...@@ -16,6 +16,7 @@ from colossalai.gemini import ChunkManager, GeminiManager ...@@ -16,6 +16,7 @@ from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.testing import parameterize from colossalai.testing import parameterize
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer from colossalai.zero import ZeroOptimizer
from colossalai.tensor import ProcessGroup
def init_zero(model, use_chunk, use_zero, placement_policy): def init_zero(model, use_chunk, use_zero, placement_policy):
...@@ -24,7 +25,8 @@ def init_zero(model, use_chunk, use_zero, placement_policy): ...@@ -24,7 +25,8 @@ def init_zero(model, use_chunk, use_zero, placement_policy):
enable_distributed_storage=use_zero, enable_distributed_storage=use_zero,
init_device=GeminiManager.get_default_device(placement_policy)) init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager) gemini_manager = GeminiManager(placement_policy, chunk_manager)
return ZeroDDP(model, gemini_manager) pg = ProcessGroup()
return ZeroDDP(model, gemini_manager, pg)
def run_step(model, optim, criterion, data, label): def run_step(model, optim, criterion, data, label):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment