Commit 3ce6a1c2 authored by Abhinav Khattar's avatar Abhinav Khattar
Browse files

add needed helper functions


Signed-off-by: default avatarAbhinav Khattar <aklife97@gmail.com>
parent 035cae2e
...@@ -233,6 +233,11 @@ def initialize_model_parallel( ...@@ -233,6 +233,11 @@ def initialize_model_parallel(
_set_global_memory_buffer() _set_global_memory_buffer()
def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None
def model_parallel_is_initialized(): def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized.""" """Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \ if _TENSOR_MODEL_PARALLEL_GROUP is None or \
...@@ -454,6 +459,12 @@ def get_virtual_pipeline_model_parallel_world_size(): ...@@ -454,6 +459,12 @@ def get_virtual_pipeline_model_parallel_world_size():
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
def set_virtual_pipeline_model_parallel_world_size(world_size):
"""Set the virtual pipeline-parallel world size"""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_src_rank(): def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank """Calculate the global rank corresponding to the first local rank
in the tensor model parallel group.""" in the tensor model parallel group."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment