"src/vscode:/vscode.git/clone" did not exist on "4a60b45d4c4c80fa934d33e51edfcd29f9795470"
Unverified commit 4d26a67c authored by Amy Yang, committed by GitHub

sync fbcode cp pg initialize (#1177)


Co-authored-by: amyyang <amyyang@meta.com>
parent 8fb39b2a
@@ -20,8 +20,6 @@
 # limitations under the License.
 """Model and data parallel groups."""
 from typing import List, Optional
 import torch
@@ -41,13 +39,13 @@ _CONTEXT_PARALLEL_GROUP_RANKS = None
 def initialize_model_parallel(
-    model_parallel_size: int,
-    context_parallel_size: int = 1,
+    model_parallel_size_: int,
+    pipeline_length: int = 1,
+    context_parallel_size: int = 1,
     *,
     model_parallel_backend: Optional[str] = None,
-    cp_backend: Optional[str] = None,
     pipeline_backend: Optional[str] = None,
+    cp_backend: Optional[str] = None,
     ddp_backend: Optional[str] = None,
 ) -> None:
     """
@@ -67,11 +65,28 @@ def initialize_model_parallel(
     are on the same DGX box. For example if we are using 2 DGX-1 boxes
     with a total of 16 GPUs, rank 0 to 7 belong to the first box and
     ranks 8 to 15 belong to the second box.
+    Process groups are initialized in the order MP, CP, PP, DP.
+    Say we have a total of 16 GPUs denoted by g0 ... g15, and we use 2 GPUs to
+    parallelize the model tensor, 2 GPUs to parallelize context (sequence
+    length), and 2 GPUs to parallelize the model pipeline. The present
+    function will create 8 tensor model-parallel groups, 8 context-parallel
+    groups, 8 pipeline model-parallel groups and 8 data-parallel groups as
+    follows, when alternate_pp_config = False:
+    8 data-parallel groups:
+        [g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
+    8 tensor model-parallel groups:
+        [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
+    8 context-parallel groups:
+        [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
+    8 pipeline model-parallel groups:
+        [g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g14], [g7, g15]
     """
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = int(min(model_parallel_size, world_size))
+    model_parallel_size = int(min(model_parallel_size_, world_size))
     ensure_divisibility(world_size, model_parallel_size)
-    ensure_divisibility(world_size, context_parallel_size)
+    ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
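
A quick way to sanity-check the group lists in the new docstring is to enumerate them by stride: per the lists above, MP indices vary with stride 1, CP with stride 2, DP with stride 4, and PP with stride 8, i.e. rank = p*8 + d*4 + c*2 + m for the 16-GPU example. The sketch below is illustrative only (the groups helper is hypothetical, not part of this module); it reproduces the docstring's groups for MP=2, CP=2, PP=2 and hence DP=2:

# Hypothetical helper, not library code: enumerate rank groups for the
# docstring's 16-GPU example (MP=2, CP=2, PP=2, DP=2).
mp_size, cp_size, pp_size, world_size = 2, 2, 2, 16
dp_size = world_size // (mp_size * cp_size * pp_size)  # = 2

def groups(stride, size, world):
    # Collect disjoint groups of `size` ranks spaced `stride` apart.
    seen, out = set(), []
    for start in range(world):
        if start in seen:
            continue
        g = [start + stride * k for k in range(size)]
        seen.update(g)
        out.append(g)
    return out

print("MP:", groups(1, mp_size, world_size))                 # [0, 1], [2, 3], ...
print("CP:", groups(mp_size, cp_size, world_size))           # [0, 2], [1, 3], ...
print("DP:", groups(mp_size * cp_size, dp_size, world_size)) # [0, 4], [1, 5], ...
print("PP:", groups(mp_size * cp_size * dp_size, pp_size, world_size))  # [0, 8], ..., [6, 14], [7, 15]

Running this prints exactly the eight groups per kind listed in the docstring, which is why the pipeline pair containing g6 must be [g6, g14] (there is no g16 among 16 GPUs).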
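
For reference, a hypothetical call against the new signature (the import path and the "nccl" backend choices are illustrative assumptions; the default process group must exist first, and the world size must be divisible by model_parallel_size_ * pipeline_length * context_parallel_size per the new check):

import torch.distributed as dist
from model_parallel.initialize import initialize_model_parallel  # hypothetical import path

dist.init_process_group(backend="nccl")  # required before creating subgroups
initialize_model_parallel(
    2,  # model_parallel_size_
    pipeline_length=2,
    context_parallel_size=2,
    model_parallel_backend="nccl",
    pipeline_backend="nccl",
    cp_backend="nccl",
    ddp_backend="nccl",
)

With 16 ranks this yields the exact MP/CP/PP/DP groups shown in the docstring; the keyword-only backend arguments let each kind of group use a different communication backend.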