Unverified Commit 8fb39b2a authored by Amy Yang's avatar Amy Yang Committed by GitHub
Browse files

add get_cp_ranks to model_parallel initialize (#1176)


Co-authored-by: default avataramyyang <amyyang@meta.com>
parent 0af41aee
......@@ -159,6 +159,12 @@ def get_context_parallel_group() -> torch.distributed.ProcessGroup:
return _CONTEXT_PARALLEL_GROUP
def get_context_parallel_ranks() -> List[int]:
"""Return context parallel ranks for the context parallel group."""
assert _CONTEXT_PARALLEL_GROUP_RANKS is not None, "context parallel group is not initialized"
return _CONTEXT_PARALLEL_GROUP_RANKS
def get_context_parallel_world_size() -> int:
"""Return world size for the context parallel group."""
return torch.distributed.get_world_size(group=get_context_parallel_group())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment