Unverified commit 0af41aee, authored by Amy Yang and committed by GitHub

add context parallel group init to mp init (#1174)


Co-authored-by: amyyang <amyyang@meta.com>
parent 9a173bf2
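
Aside (not part of the commit): a minimal usage sketch of the updated initializer. The module path assumes a fairscale-style layout (fairscale/nn/model_parallel/initialize.py), the launch assumes torchrun-style environment variables, and the parallel sizes are illustrative only.

import torch
import torch.distributed as dist

# Assumed module path; adjust to wherever this file lives in your checkout.
from fairscale.nn.model_parallel import initialize as mpu


def main() -> None:
    # Expects a torchrun-style launch, e.g. `torchrun --nproc_per_node=8 this_file.py`,
    # which sets RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT for init_process_group.
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")

    # World size must be divisible by model * pipeline * context sizes; the
    # remainder becomes the data parallel size (8 / (2 * 1 * 2) = 2 here).
    mpu.initialize_model_parallel(
        model_parallel_size=2,
        context_parallel_size=2,
        pipeline_length=1,
    )
    assert mpu.model_parallel_is_initialized()

    print(
        f"global rank {dist.get_rank()}: "
        f"context parallel rank {mpu.get_context_parallel_rank()} "
        f"of {mpu.get_context_parallel_world_size()}"
    )

    mpu.destroy_model_parallel()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
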
@@ -34,17 +34,21 @@ _MODEL_PARALLEL_GROUP = None
 _DATA_PARALLEL_GROUP = None
 # Pipeline parallel group that the current rank belongs to.
 _PIPELINE_PARALLEL_GROUP = None
 _PIPELINE_PARALLEL_RANKS = None
+_CONTEXT_PARALLEL_GROUP = None
+_CONTEXT_PARALLEL_GROUP_RANKS = None
 
 
 def initialize_model_parallel(
-    model_parallel_size_: int,
+    model_parallel_size: int,
+    context_parallel_size: int = 1,
     pipeline_length: int = 1,
     *,
    model_parallel_backend: Optional[str] = None,
+    cp_backend: Optional[str] = None,
     pipeline_backend: Optional[str] = None,
-    ddp_backend: Optional[str] = None
+    ddp_backend: Optional[str] = None,
 ) -> None:
     """
     Initialize model data parallel groups.
@@ -67,19 +71,21 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = int(min(model_parallel_size_, world_size))
+    model_parallel_size = int(min(model_parallel_size, world_size))
     ensure_divisibility(world_size, model_parallel_size)
-    ensure_divisibility(world_size, model_parallel_size * pipeline_length)
+    ensure_divisibility(world_size, context_parallel_size)
+    ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
     rank = torch.distributed.get_rank()
-    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length))
+    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length * context_parallel_size))
 
     if torch.distributed.get_rank() == 0:
-        print("> initializing model parallel with size {}".format(model_parallel_size_))
-        print("> initializing ddp with size {}".format(data_parallel_size))
+        print("> initializing model parallel with size {}".format(model_parallel_size))
+        print("> initializing context parallel with size {}".format(context_parallel_size))
         print("> initializing pipeline with size {}".format(pipeline_length))
+        print("> initializing ddp with size {}".format(data_parallel_size))
 
-    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size)
+    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, context_parallel_size, model_parallel_size)
 
     found = torch.where(groups == rank)
     assert all(len(x) == 1 for x in found)
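
Aside (not part of the diff): the 4-D reshape above fully determines which global ranks end up in each kind of group. A minimal sketch, assuming an illustrative world size of 16 with data=2, pipeline=2, context=2, model=2, that reproduces the indexing with plain tensors and no process groups:

import torch

# Illustrative sizes (assumption): world size 16 = data 2 x pipeline 2 x context 2 x model 2.
data, pipe, ctx, mp = 2, 2, 2, 2
world_size = data * pipe * ctx * mp
groups = torch.LongTensor(range(world_size)).reshape(data, pipe, ctx, mp)

rank = 5  # any global rank
found = torch.where(groups == rank)
i, j, k, m = (int(x) for x in found)  # (data, pipeline, context, model) coordinates

print("data parallel peers:    ", groups[:, j, k, m].tolist())  # vary the data axis
print("pipeline parallel peers:", groups[i, :, k, m].tolist())  # vary the pipeline axis
print("context parallel peers: ", groups[i, j, :, m].tolist())  # vary the context axis
print("model parallel peers:   ", groups[i, j, k, :].tolist())  # vary the model axis
# For rank 5 this prints: data [5, 13], pipeline [1, 5], context [5, 7], model [4, 5].
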
@@ -88,41 +94,81 @@
     # Build the data parallel groups.
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
-    for j in range(pipeline_length):
-        for k in range(model_parallel_size):
-            group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend)
-            if j == found[1] and k == found[2]:
-                _DATA_PARALLEL_GROUP = group
+    for i in range(pipeline_length):
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                group = torch.distributed.new_group(groups[:, i, j, k].tolist(), backend=ddp_backend)
+                if i == found[1] and j == found[2] and k == found[3]:
+                    _DATA_PARALLEL_GROUP = group
 
     # Build the model parallel groups.
     global _MODEL_PARALLEL_GROUP
-    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
+    assert _MODEL_PARALLEL_GROUP is None, "Model parallel group is already initialized"
     for i in range(data_parallel_size):
         for j in range(pipeline_length):
-            group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend)
-            if i == found[0] and j == found[1]:
-                _MODEL_PARALLEL_GROUP = group
+            for k in range(context_parallel_size):
+                group = torch.distributed.new_group(groups[i, j, k, :].tolist(), backend=model_parallel_backend)
+                if i == found[0] and j == found[1] and k == found[2]:
+                    _MODEL_PARALLEL_GROUP = group
 
     # Build the pipeline parallel groups.
     global _PIPELINE_PARALLEL_GROUP
-    assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized"
     global _PIPELINE_PARALLEL_RANKS
-    assert _PIPELINE_PARALLEL_RANKS is None, "model parallel group is already initialized"
+    assert _PIPELINE_PARALLEL_GROUP is None, "Pipeline parallel group is already initialized"
+    for i in range(data_parallel_size):
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                ranks = groups[i, :, j, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=pipeline_backend)
+                if i == found[0] and j == found[2] and k == found[3]:
+                    _PIPELINE_PARALLEL_GROUP = group
+                    _PIPELINE_PARALLEL_RANKS = ranks
+
+    # Build the context parallel groups.
+    global _CONTEXT_PARALLEL_GROUP
+    global _CONTEXT_PARALLEL_GROUP_RANKS
+    assert (
+        _CONTEXT_PARALLEL_GROUP is None
+    ), "Context parallelism is already initialized."
     for i in range(data_parallel_size):
-        for k in range(model_parallel_size):
-            ranks = groups[i, :, k].tolist()
-            group = torch.distributed.new_group(ranks, backend=pipeline_backend)
-            if i == found[0] and k == found[2]:
-                _PIPELINE_PARALLEL_GROUP = group
-                _PIPELINE_PARALLEL_RANKS = ranks
+        for j in range(pipeline_length):
+            for k in range(model_parallel_size):
+                ranks = groups[i, j, :, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=cp_backend)
+                if i == found[0] and j == found[1] and k == found[3]:
+                    _CONTEXT_PARALLEL_GROUP = group
+                    _CONTEXT_PARALLEL_GROUP_RANKS = ranks
 
 
 def model_parallel_is_initialized() -> bool:
     """Check if model and data parallel groups are initialized."""
-    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None:
+    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None or _CONTEXT_PARALLEL_GROUP is None:
         return False
     return True
 
 
+def get_context_parallel_group() -> torch.distributed.ProcessGroup:
+    """Get the context parallel group the caller rank belongs to."""
+    assert (
+        _CONTEXT_PARALLEL_GROUP is not None
+    ), "context parallel group is not initialized"
+    return _CONTEXT_PARALLEL_GROUP
+
+
+def get_context_parallel_world_size() -> int:
+    """Return world size for the context parallel group."""
+    return torch.distributed.get_world_size(group=get_context_parallel_group())
+
+
+def get_context_parallel_rank() -> int:
+    """Return my rank for the context parallel group."""
+    return torch.distributed.get_rank(group=get_context_parallel_group())
+
+
 def get_model_parallel_group() -> torch.distributed.ProcessGroup:
     """Get the model parallel group the caller rank belongs to."""
     assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
@@ -179,10 +225,16 @@ def destroy_model_parallel() -> None:
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
     _MODEL_PARALLEL_GROUP = None
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
     global _PIPELINE_PARALLEL_GROUP
     _PIPELINE_PARALLEL_GROUP = None
     global _PIPELINE_PARALLEL_RANKS
     _PIPELINE_PARALLEL_RANKS = None
+    global _CONTEXT_PARALLEL_GROUP
+    _CONTEXT_PARALLEL_GROUP = None
+    global _CONTEXT_PARALLEL_GROUP_RANKS
+    _CONTEXT_PARALLEL_GROUP_RANKS = None
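
Aside (not part of the commit): one plausible consumer of the new context parallel group, re-assembling sequence shards that were split across context parallel ranks. Only the getters come from this change; the helper name, the sharding scheme, and the module path are illustrative assumptions.

import torch
import torch.distributed as dist

# Assumed module path, as in the earlier sketch.
from fairscale.nn.model_parallel import initialize as mpu


def gather_context_shards(local_chunk: torch.Tensor) -> torch.Tensor:
    """All-gather [local_seq, hidden] shards from every context parallel rank
    and concatenate them into the full [seq, hidden] tensor along dim 0."""
    group = mpu.get_context_parallel_group()
    world = mpu.get_context_parallel_world_size()
    shards = [torch.empty_like(local_chunk) for _ in range(world)]
    dist.all_gather(shards, local_chunk, group=group)
    return torch.cat(shards, dim=0)
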