"fair_dev/testing/testing.py" did not exist on "fee979d9b0617c219a122d51ab97e962d7d4d694"
Unverified commit 0af41aee authored by Amy Yang, committed by GitHub

add context parallel group init to mp init (#1174)


Co-authored-by: amyyang <amyyang@meta.com>
parent 9a173bf2
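
The hunks below extend the process-group layout from a 3-D (data, pipeline, model) rank grid to a 4-D (data, pipeline, context, model) grid: each parallel group is a slice of that grid along one axis with the other three indices held fixed. As a rough illustration of what the new layout produces, here is a minimal standalone sketch (not the patched module itself; world_size = 16 and the parallel sizes are made-up values) that enumerates the data-parallel and context-parallel groups using the same reshape the patch uses:

# Illustrative sketch only: sizes below are assumptions, not values from the patch.
import torch

world_size = 16
model_parallel_size = 2
context_parallel_size = 2
pipeline_length = 2
data_parallel_size = world_size // (model_parallel_size * pipeline_length * context_parallel_size)

# Same reshape as the patch: axes are (data, pipeline, context, model).
groups = torch.LongTensor(range(world_size)).reshape(
    data_parallel_size, pipeline_length, context_parallel_size, model_parallel_size
)

# Data-parallel groups vary the data axis, holding the other three indices fixed.
for i in range(pipeline_length):
    for j in range(context_parallel_size):
        for k in range(model_parallel_size):
            print("data-parallel group:", groups[:, i, j, k].tolist())

# Context-parallel groups vary the context axis, holding the other three indices fixed.
for i in range(data_parallel_size):
    for j in range(pipeline_length):
        for k in range(model_parallel_size):
            print("context-parallel group:", groups[i, j, :, k].tolist())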
@@ -34,17 +34,21 @@ _MODEL_PARALLEL_GROUP = None
 _DATA_PARALLEL_GROUP = None
 # Pipeline parallel group that the current rank belongs to.
 _PIPELINE_PARALLEL_GROUP = None
 _PIPELINE_PARALLEL_RANKS = None
+_CONTEXT_PARALLEL_GROUP = None
+_CONTEXT_PARALLEL_GROUP_RANKS = None
 def initialize_model_parallel(
-    model_parallel_size_: int,
+    model_parallel_size: int,
+    context_parallel_size: int = 1,
     pipeline_length: int = 1,
     *,
     model_parallel_backend: Optional[str] = None,
+    cp_backend: Optional[str] = None,
     pipeline_backend: Optional[str] = None,
-    ddp_backend: Optional[str] = None
+    ddp_backend: Optional[str] = None,
 ) -> None:
     """
     Initialize model data parallel groups.
@@ -67,19 +71,21 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = int(min(model_parallel_size_, world_size))
+    model_parallel_size = int(min(model_parallel_size, world_size))
     ensure_divisibility(world_size, model_parallel_size)
-    ensure_divisibility(world_size, model_parallel_size * pipeline_length)
+    ensure_divisibility(world_size, context_parallel_size)
+    ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
     rank = torch.distributed.get_rank()
-    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length))
+    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length * context_parallel_size))
     if torch.distributed.get_rank() == 0:
-        print("> initializing model parallel with size {}".format(model_parallel_size_))
-        print("> initializing ddp with size {}".format(data_parallel_size))
+        print("> initializing model parallel with size {}".format(model_parallel_size))
+        print("> initializing context parallel with size {}".format(context_parallel_size))
         print("> initializing pipeline with size {}".format(pipeline_length))
+        print("> initializing ddp with size {}".format(data_parallel_size))
-    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size)
+    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, context_parallel_size, model_parallel_size)
     found = torch.where(groups == rank)
     assert all(len(x) == 1 for x in found)
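
For the size bookkeeping in the hunk above, a small worked example (all values illustrative, not from the patch) shows how the divisibility checks and the derived data-parallel size fit together:

# Worked example of the new size arithmetic; the concrete numbers are assumptions.
world_size = 32
model_parallel_size = 4
context_parallel_size = 2
pipeline_length = 2

# The checks the patch performs via ensure_divisibility:
assert world_size % model_parallel_size == 0
assert world_size % context_parallel_size == 0
assert world_size % (model_parallel_size * pipeline_length * context_parallel_size) == 0

# The leftover factor becomes the data-parallel size.
data_parallel_size = world_size // (model_parallel_size * pipeline_length * context_parallel_size)
print(data_parallel_size)  # 2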
@@ -88,41 +94,81 @@ def initialize_model_parallel(
     # Build the data parallel groups.
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
-    for j in range(pipeline_length):
-        for k in range(model_parallel_size):
-            group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend)
-            if j == found[1] and k == found[2]:
-                _DATA_PARALLEL_GROUP = group
+    for i in range(pipeline_length):
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                group = torch.distributed.new_group(groups[:, i, j, k].tolist(), backend=ddp_backend)
+                if i == found[1] and j == found[2] and k == found[3]:
+                    _DATA_PARALLEL_GROUP = group
     # Build the model parallel groups.
     global _MODEL_PARALLEL_GROUP
-    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
+    assert _MODEL_PARALLEL_GROUP is None, "Model parallel group is already initialized"
     for i in range(data_parallel_size):
         for j in range(pipeline_length):
-            group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend)
-            if i == found[0] and j == found[1]:
-                _MODEL_PARALLEL_GROUP = group
+            for k in range(context_parallel_size):
+                group = torch.distributed.new_group(groups[i, j, k, :].tolist(), backend=model_parallel_backend)
+                if i == found[0] and j == found[1] and k == found[2]:
+                    _MODEL_PARALLEL_GROUP = group
+    # Build the pipeline parallel groups.
     global _PIPELINE_PARALLEL_GROUP
-    assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized"
     global _PIPELINE_PARALLEL_RANKS
-    assert _PIPELINE_PARALLEL_RANKS is None, "model parallel group is already initialized"
+    assert _PIPELINE_PARALLEL_GROUP is None, "Pipeline parallel group is already initialized"
     for i in range(data_parallel_size):
-        for k in range(model_parallel_size):
-            ranks = groups[i, :, k].tolist()
-            group = torch.distributed.new_group(ranks, backend=pipeline_backend)
-            if i == found[0] and k == found[2]:
-                _PIPELINE_PARALLEL_GROUP = group
-                _PIPELINE_PARALLEL_RANKS = ranks
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                ranks = groups[i, :, j, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=pipeline_backend)
+                if i == found[0] and j == found[2] and k == found[3]:
+                    _PIPELINE_PARALLEL_GROUP = group
+                    _PIPELINE_PARALLEL_RANKS = ranks
+    # Build the context parallel groups.
+    global _CONTEXT_PARALLEL_GROUP
+    global _CONTEXT_PARALLEL_GROUP_RANKS
+    assert (
+        _CONTEXT_PARALLEL_GROUP is None
+    ), "Context parallelism is already initialized."
+    for i in range(data_parallel_size):
+        for j in range(pipeline_length):
+            for k in range(model_parallel_size):
+                ranks = groups[i, j, :, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=cp_backend)
+                if i == found[0] and j == found[1] and k == found[3]:
+                    _CONTEXT_PARALLEL_GROUP = group
+                    _CONTEXT_PARALLEL_GROUP_RANKS = ranks
 def model_parallel_is_initialized() -> bool:
     """Check if model and data parallel groups are initialized."""
-    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None:
+    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None or _CONTEXT_PARALLEL_GROUP is None:
         return False
     return True
+def get_context_parallel_group() -> torch.distributed.ProcessGroup:
+    """Get the context parallel group the caller rank belongs to."""
+    assert (
+        _CONTEXT_PARALLEL_GROUP is not None
+    ), "context parallel group is not initialized"
+    return _CONTEXT_PARALLEL_GROUP
+def get_context_parallel_world_size() -> int:
+    """Return world size for the context parallel group."""
+    return torch.distributed.get_world_size(group=get_context_parallel_group())
+def get_context_parallel_rank() -> int:
+    """Return my rank for the context parallel group."""
+    return torch.distributed.get_rank(group=get_context_parallel_group())
 def get_model_parallel_group() -> torch.distributed.ProcessGroup:
     """Get the model parallel group the caller rank belongs to."""
     assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
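
As a hedged usage sketch of the new accessors, the following single-process script shows how a caller might exercise the context-parallel dimension after this patch. The import path (fairscale.nn.model_parallel.initialize) and the single-rank gloo setup are assumptions for illustration; with world_size = 1 every parallel group collapses to the caller's own rank.

# Sketch only: module path and launch configuration are assumptions, not part of the patch.
import os
import torch.distributed as dist
from fairscale.nn.model_parallel import initialize as mpu

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# New keyword arguments from this change: context_parallel_size (and cp_backend).
mpu.initialize_model_parallel(1, context_parallel_size=1, pipeline_length=1)

print(mpu.get_context_parallel_world_size())  # 1 on a single-rank run
print(mpu.get_context_parallel_rank())        # 0

mpu.destroy_model_parallel()
dist.destroy_process_group()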
@@ -179,10 +225,16 @@ def destroy_model_parallel() -> None:
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
     _MODEL_PARALLEL_GROUP = None
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
     global _PIPELINE_PARALLEL_GROUP
     _PIPELINE_PARALLEL_GROUP = None
     global _PIPELINE_PARALLEL_RANKS
     _PIPELINE_PARALLEL_RANKS = None
+    global _CONTEXT_PARALLEL_GROUP
+    _CONTEXT_PARALLEL_GROUP = None
+    global _CONTEXT_PARALLEL_GROUP_RANKS
+    _CONTEXT_PARALLEL_GROUP_RANKS = None