Commit ee084835 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'add_helper_fns' into 'main'

Add parallel_state helper functions

See merge request ADLR/megatron-lm!564
parents 9c5e287d 3ce6a1c2
......@@ -233,6 +233,11 @@ def initialize_model_parallel(
_set_global_memory_buffer()
def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
......@@ -454,6 +459,12 @@ def get_virtual_pipeline_model_parallel_world_size():
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
def set_virtual_pipeline_model_parallel_world_size(world_size):
"""Set the virtual pipeline-parallel world size"""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment