Unverified Commit 4d26a67c authored by Amy Yang's avatar Amy Yang Committed by GitHub
Browse files

sync fbcode cp pg initialize (#1177)


Co-authored-by: default avataramyyang <amyyang@meta.com>
parent 8fb39b2a
...@@ -20,8 +20,6 @@ ...@@ -20,8 +20,6 @@
# limitations under the License. # limitations under the License.
"""Model and data parallel groups."""
from typing import List, Optional from typing import List, Optional
import torch import torch
...@@ -41,13 +39,13 @@ _CONTEXT_PARALLEL_GROUP_RANKS = None ...@@ -41,13 +39,13 @@ _CONTEXT_PARALLEL_GROUP_RANKS = None
def initialize_model_parallel( def initialize_model_parallel(
model_parallel_size: int, model_parallel_size_: int,
context_parallel_size: int = 1,
pipeline_length: int = 1, pipeline_length: int = 1,
context_parallel_size: int = 1,
*, *,
model_parallel_backend: Optional[str] = None, model_parallel_backend: Optional[str] = None,
cp_backend: Optional[str] = None,
pipeline_backend: Optional[str] = None, pipeline_backend: Optional[str] = None,
cp_backend: Optional[str] = None,
ddp_backend: Optional[str] = None, ddp_backend: Optional[str] = None,
) -> None: ) -> None:
""" """
...@@ -67,11 +65,28 @@ def initialize_model_parallel( ...@@ -67,11 +65,28 @@ def initialize_model_parallel(
are on the same DGX box. For example if we are using 2 DGX-1 boxes are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box. ranks 8 to 15 belong to the second box.
process groups initialized in the order of MP, CP, PP, DP.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, 2 GPUs to parallelize context(seq len), and 2 GPUs to parallelize
the model pipeline. The present function will
create 8 tensor model-parallel groups, 8 context-parallel group, 8 pipeline model-parallel groups
and 8 data-parallel groups as:
when alternate_pp_config = False,
8 data_parallel groups:
[g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
8 context-parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 pipeline model-parallel groups:
[g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g16], [g7, g15]
""" """
# Get world size and rank. Ensure some consistencies. # Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized() assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
model_parallel_size = int(min(model_parallel_size, world_size)) model_parallel_size = int(min(model_parallel_size_, world_size))
ensure_divisibility(world_size, model_parallel_size) ensure_divisibility(world_size, model_parallel_size)
ensure_divisibility(world_size, context_parallel_size) ensure_divisibility(world_size, context_parallel_size)
ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size) ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment