Commit e00a1ca1 authored by Jared Casper's avatar Jared Casper
Browse files

Improve docstrings, destory global memory buffer.

parent a94d0a6e
...@@ -428,18 +428,23 @@ def get_data_parallel_src_rank(): ...@@ -428,18 +428,23 @@ def get_data_parallel_src_rank():
def get_pipeline_model_parallel_first_rank(): def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0] return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank(): def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1 last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local] return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank(): def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank() rank_in_pipeline = get_pipeline_model_parallel_rank()
...@@ -448,6 +453,7 @@ def get_pipeline_model_parallel_next_rank(): ...@@ -448,6 +453,7 @@ def get_pipeline_model_parallel_next_rank():
def get_pipeline_model_parallel_prev_rank(): def get_pipeline_model_parallel_prev_rank():
"""Return the global rank that preceeds the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \ assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized" "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank() rank_in_pipeline = get_pipeline_model_parallel_rank()
...@@ -471,11 +477,11 @@ def _set_global_memory_buffer(): ...@@ -471,11 +477,11 @@ def _set_global_memory_buffer():
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
def get_global_memory_buffer(): def get_global_memory_buffer():
"""Return the global GlobalMemoryBuffer object"""
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized' assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
return _GLOBAL_MEMORY_BUFFER return _GLOBAL_MEMORY_BUFFER
def destroy_model_parallel(): def destroy_model_parallel():
"""Set the groups to none.""" """Set the groups to none."""
global _MODEL_PARALLEL_GROUP global _MODEL_PARALLEL_GROUP
...@@ -502,3 +508,5 @@ def destroy_model_parallel(): ...@@ -502,3 +508,5 @@ def destroy_model_parallel():
_MPU_TENSOR_MODEL_PARALLEL_RANK = None _MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
global _GLOBAL_MEMORY_BUFFER
_GLOBAL_MEMORY_BUFFER = None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment