Unverified Commit 7487215b authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[ColoTensor] add independent process group (#1179)

parent 26ba8727
...@@ -7,8 +7,10 @@ from .dist_spec_mgr import DistSpecManager ...@@ -7,8 +7,10 @@ from .dist_spec_mgr import DistSpecManager
from .param_op_hook import ParamOpHook, ParamOpHookManager from .param_op_hook import ParamOpHook, ParamOpHookManager
from .chunk import ChunkManager, TensorState from .chunk import ChunkManager, TensorState
from . import distspec from . import distspec
from .process_group import ProcessGroup
__all__ = [ __all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor', 'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState' 'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState',
'ProcessGroup'
] ]
import torch
from typing import List, Optional
class ProcessGroup:
    """
    ProcessGroup contains the group partition for Tensor Parallel (TP) and Data Parallel (DP).

    WARNING: ProcessGroup must be used after ``torch.distributed`` is initialized.

    Args:
        rank: the global rank of the current process.
        ranks: list of global rank ids belonging to this process group.
        backend: backend of the process group, default 'nccl'.
        tp_degree: tensor parallelism degree; default None means 1.
        dp_degree: data parallelism degree; default None means ``len(ranks) // tp_degree``.
    """

    def __init__(self,
                 rank: int,
                 ranks: List[int],
                 backend: str = 'nccl',
                 tp_degree: Optional[int] = None,
                 dp_degree: Optional[int] = None) -> None:
        self._rank = rank
        self._rank_list = ranks
        self._backend = backend
        self._world_size = len(self._rank_list)

        assert torch.distributed.is_initialized(), "ProcessGroup must be used after distributed is initialized"

        # Resolve the TP/DP degrees. The world size must factor as dp_degree * tp_degree.
        # Use integer division (//) so the degrees stay ints for the rank arithmetic below.
        if dp_degree is None and tp_degree is None:
            # default: pure data parallelism
            self._dp_degree = self._world_size
            self._tp_degree = 1
        elif dp_degree is not None and tp_degree is None:
            assert self._world_size % dp_degree == 0, \
                f"world size {self._world_size} should be divisible by DP degree {dp_degree} when TP degree is None"
            self._dp_degree = dp_degree
            self._tp_degree = self._world_size // dp_degree
        elif tp_degree is not None and dp_degree is None:
            assert self._world_size % tp_degree == 0, \
                f"world size {self._world_size} should be divisible by TP degree {tp_degree} when DP degree is None"
            self._tp_degree = tp_degree
            self._dp_degree = self._world_size // tp_degree
        else:
            # both degrees given explicitly: they must exactly tile the world size
            # (the original code fell through here and crashed with AttributeError)
            assert dp_degree * tp_degree == self._world_size, \
                f"DP degree {dp_degree} * TP degree {tp_degree} should equal world size {self._world_size}"
            self._dp_degree = dp_degree
            self._tp_degree = tp_degree

        # Partition ranks into TP groups (contiguous blocks of tp_degree ranks) and
        # DP groups (ranks at the same offset within each TP block).
        # NOTE(review): this assumes `ranks` is exactly range(world_size) — the loop
        # iterates rank ids 0..world_size-1, not the entries of `ranks`; verify callers.
        self._tp_rank_list = []
        self._dp_rank_list = []
        for rank_id in range(self._world_size):
            # rank_id is in the same DP group as self._rank: same column within a TP block
            if rank_id % self._tp_degree == self._rank % self._tp_degree:
                self._dp_rank_list.append(rank_id)
            # rank_id is in the same TP group as self._rank: same contiguous TP block
            if rank_id // self._tp_degree == self._rank // self._tp_degree:
                self._tp_rank_list.append(rank_id)

        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend=backend)
        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend=backend)

    def world_size(self):
        """Total number of ranks in this process group."""
        return self._world_size

    def dp_world_size(self):
        """Number of ranks in the current process's data-parallel group."""
        return len(self._dp_rank_list)

    def tp_world_size(self):
        """Number of ranks in the current process's tensor-parallel group."""
        return len(self._tp_rank_list)

    def dp_process_group(self):
        """The torch.distributed group for data parallelism."""
        return self._dp_process_group

    def tp_process_group(self):
        """The torch.distributed group for tensor parallelism."""
        return self._tp_process_group
...@@ -10,7 +10,7 @@ from colossalai.utils.cuda import get_current_device ...@@ -10,7 +10,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \ from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
ComputeSpec, ColoTensor, DistSpecManager ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColoOptimizer from colossalai.nn.optimizer import ColoOptimizer
...@@ -18,34 +18,30 @@ from functools import partial ...@@ -18,34 +18,30 @@ from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed from _utils import tensor_equal, tensor_shard_equal, set_seed
def init_1d_row_linear(weight): def init_1d_row_linear(weight, pg: ProcessGroup):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), ComputeSpec(ComputePattern.TP1D))
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
def init_1d_col_linear(weight): def init_1d_col_linear(weight, pg):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), ComputeSpec(ComputePattern.TP1D))
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
def init_1d_row_embedding(weight): def init_1d_row_embedding(weight, pg):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), ComputeSpec(ComputePattern.TP1D))
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
def init_1d_col_embedding(weight): def init_1d_col_embedding(weight, pg):
spec = TensorSpec( spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), ComputeSpec(ComputePattern.TP1D))
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
weight.set_tensor_spec(spec) weight.set_tensor_spec(spec)
...@@ -69,6 +65,9 @@ def run_1d_hybrid_tp(model_name): ...@@ -69,6 +65,9 @@ def run_1d_hybrid_tp(model_name):
for p1, p2 in zip(model.parameters(), model_torch.parameters()): for p1, p2 in zip(model.parameters(), model_torch.parameters()):
p2.data.copy_(p1.data) p2.data.copy_(p1.data)
rank = gpc.get_local_rank(ParallelMode.GLOBAL)
world_size = gpc.get_world_size(ParallelMode.GLOBAL)
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
if 'bert' == model_name: if 'bert' == model_name:
for name, p in model.named_parameters(): for name, p in model.named_parameters():
if not isinstance(p, ColoTensor): if not isinstance(p, ColoTensor):
...@@ -76,29 +75,29 @@ def run_1d_hybrid_tp(model_name): ...@@ -76,29 +75,29 @@ def run_1d_hybrid_tp(model_name):
# print(name) # print(name)
# num_class = type_vocab_size = 2 | (8, 2) # num_class = type_vocab_size = 2 | (8, 2)
if 'classifier' in name and 'weight' in name: if 'classifier' in name and 'weight' in name:
init_1d_row_linear(p) init_1d_row_linear(p, pg)
# num_class = vocab_size = 30524 | (30524, 8) # num_class = vocab_size = 30524 | (30524, 8)
if 'word_embeddings' in name and 'weight' in name: if 'word_embeddings' in name and 'weight' in name:
init_1d_row_embedding(p) init_1d_row_embedding(p, pg)
# num_class = seq_len = 512 | (512, 8) # num_class = seq_len = 512 | (512, 8)
if 'position_embeddings' in name and 'weight' in name: if 'position_embeddings' in name and 'weight' in name:
init_1d_row_embedding(p) init_1d_row_embedding(p, pg)
# num_class = type_vocab_size = 2 | (2, 8) # num_class = type_vocab_size = 2 | (2, 8)
if 'token_type_embeddings' in name and 'weight' in name: if 'token_type_embeddings' in name and 'weight' in name:
init_1d_col_embedding(p) init_1d_col_embedding(p, pg)
elif "simple_net" == model_name: elif "simple_net" == model_name:
# A naive way to set spec for all weights in Linear # A naive way to set spec for all weights in Linear
for name, p in model.named_parameters(): for name, p in model.named_parameters():
if not isinstance(p, ColoTensor): if not isinstance(p, ColoTensor):
continue continue
if 'embed' in name and 'weight' in name: if 'embed' in name and 'weight' in name:
init_1d_col_embedding(p) init_1d_col_embedding(p, pg)
if 'proj1' in name and ('weight' in name or 'bias' in name): if 'proj1' in name and ('weight' in name or 'bias' in name):
init_1d_col_linear(p) init_1d_col_linear(p, pg)
if 'proj2' in name and 'weight' in name: if 'proj2' in name and 'weight' in name:
init_1d_row_linear(p) init_1d_row_linear(p, pg)
if 'classifier' in name and ('weight' in name or 'bias' in name): if 'classifier' in name and ('weight' in name or 'bias' in name):
init_1d_col_linear(p) init_1d_col_linear(p, pg)
model = model.cuda() model = model.cuda()
colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1) colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
...@@ -112,8 +111,8 @@ def run_1d_hybrid_tp(model_name): ...@@ -112,8 +111,8 @@ def run_1d_hybrid_tp(model_name):
data = data.to(get_current_device()) data = data.to(get_current_device())
label = label.to(get_current_device()) label = label.to(get_current_device())
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D)) torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D)) torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
# Bcast rank0 data to all processes # Bcast rank0 data to all processes
if criterion: if criterion:
output = model(data) output = model(data)
...@@ -221,6 +220,10 @@ def run_1d_row_tp(model_name: str): ...@@ -221,6 +220,10 @@ def run_1d_row_tp(model_name: str):
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
model = model_builder(checkpoint=True) model = model_builder(checkpoint=True)
rank = gpc.get_local_rank(ParallelMode.GLOBAL)
world_size = gpc.get_world_size(ParallelMode.GLOBAL)
pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
set_seed(1) set_seed(1)
if rank == 0: if rank == 0:
model_torch = model_builder(checkpoint=True) model_torch = model_builder(checkpoint=True)
...@@ -230,9 +233,9 @@ def run_1d_row_tp(model_name: str): ...@@ -230,9 +233,9 @@ def run_1d_row_tp(model_name: str):
if not isinstance(p, ColoTensor): if not isinstance(p, ColoTensor):
continue continue
if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name: if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
init_1d_row_linear(p) init_1d_row_linear(p, pg)
if 'embed' in name and 'weight' in name: if 'embed' in name and 'weight' in name:
init_1d_row_embedding(p) init_1d_row_embedding(p, pg)
model = model.cuda() model = model.cuda()
...@@ -330,10 +333,11 @@ def run_pretrain_load_dist(rank, world_size, port): ...@@ -330,10 +333,11 @@ def run_pretrain_load_dist(rank, world_size, port):
# The test case has to download huggingface pretrained models from the internet # The test case has to download huggingface pretrained models from the internet
# So we manually trigger the test. # So we manually trigger the test.
@pytest.mark.skip
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def _test_pretrain_load(world_size): def test_pretrain_load(world_size):
run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port()) run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
...@@ -342,4 +346,4 @@ if __name__ == '__main__': ...@@ -342,4 +346,4 @@ if __name__ == '__main__':
# test_model_parameters() # test_model_parameters()
# test_colo_optimizer() # test_colo_optimizer()
# test_model(4) # test_model(4)
_test_pretrain_load(4) test_pretrain_load(4)
...@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc ...@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.tensor import distspec, TensorSpec, ColoTensor from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from functools import partial from functools import partial
...@@ -21,14 +21,6 @@ def test_tensor_indexing(): ...@@ -21,14 +21,6 @@ def test_tensor_indexing():
assert allclose(torch_t[:, 1], colo_t[:, 1]) assert allclose(torch_t[:, 1], colo_t[:, 1])
@pytest.mark.skip
# FIXME(ver217): support lazy init
def test_lazy_init_tensor():
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor.numel() == 0
assert lazy_t.numel() == 6 == lazy_t.torch_tensor().numel()
def test_wrapped_tensor_func(): def test_wrapped_tensor_func():
t_ref = torch.randn(4, 5) t_ref = torch.randn(4, 5)
t = ColoTensor.from_torch_tensor(t_ref.clone()) t = ColoTensor.from_torch_tensor(t_ref.clone())
...@@ -62,10 +54,12 @@ def test_operand(): ...@@ -62,10 +54,12 @@ def test_operand():
def _run_view(world_size): def _run_view(world_size):
t_ref = torch.randn(4, 5) t_ref = torch.randn(4, 5)
rank = gpc.get_global_rank()
pg = ProcessGroup(rank, list(range(world_size)))
assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"
t = ColoTensor.from_torch_tensor( t = ColoTensor.from_torch_tensor(
t_ref, t_ref,
TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], TensorSpec(distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])))
num_partitions=[world_size])))
assert t.size_global()[0] == 4 * world_size assert t.size_global()[0] == 4 * world_size
assert t.size_global(1) == 5 assert t.size_global(1) == 5
...@@ -81,8 +75,10 @@ def _run_view(world_size): ...@@ -81,8 +75,10 @@ def _run_view(world_size):
def _run_tensor_shard_init(world_size): def _run_tensor_shard_init(world_size):
t_ref = torch.randn(4, 5) t_ref = torch.randn(4, 5)
print(gpc.get_group(ParallelMode.DATA).size())
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size]) rank = gpc.get_global_rank()
pg = ProcessGroup(rank, list(range(world_size)))
shard_spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])
tensor_spec = TensorSpec(shard_spec) tensor_spec = TensorSpec(shard_spec)
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate())) t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment