"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "6143af2112cd404adca42111daf14397a45f2d7e"
Commit 136cf036 authored by shanmugamr
parents b69e2195 6d417896
@@ -296,6 +296,12 @@ def set_pipeline_model_parallel_rank(rank):
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank


+def set_pipeline_model_parallel_split_rank(rank):
+    """Set pipeline model parallel split rank."""
+    global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
+
+
 def get_tensor_model_parallel_rank():
     """Return my rank for the tensor model parallel group."""
     global _MPU_TENSOR_MODEL_PARALLEL_RANK
...
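For context, here is a minimal sketch of the module-level parallel-state pattern this hunk extends. In Megatron-style code the split rank typically marks the pipeline stage where an encoder/decoder model transitions from encoder to decoder. The get_pipeline_model_parallel_split_rank accessor below is assumed for illustration; it does not appear in this diff.

# Sketch only: the getter is an assumption, not shown in the hunk above.
_MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

def set_pipeline_model_parallel_split_rank(rank):
    """Set pipeline model parallel split rank."""
    global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank

def get_pipeline_model_parallel_split_rank():
    """Return the split rank, or None if it was never set (assumed accessor)."""
    return _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK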
@@ -163,13 +163,6 @@ def model_parallel_cuda_manual_seed(seed):
     # Data parallel gets the original seed.
     data_parallel_seed = seed

-    if torch.distributed.get_rank() == 0:
-        print('> initializing model parallel cuda seeds on global rank {}, '
-              'model parallel rank {}, and data parallel rank {} with '
-              'model parallel seed: {} and data parallel seed: {}'.format(
-                  torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
-                  get_data_parallel_rank(), tensor_model_parallel_seed,
-                  data_parallel_seed), flush=True)
     _CUDA_RNG_STATE_TRACKER.reset()
     # Set the default state.
     torch.cuda.manual_seed(data_parallel_seed)
...
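The deleted block was a rank-0 diagnostic print; the seeding behavior itself is unchanged. For reference, a hedged sketch of the seed derivation around it, assuming a Megatron-style scheme; the offset constant here is illustrative rather than taken from this file.

import torch

def derive_parallel_seeds(seed, tensor_model_parallel_rank):
    # Tensor-parallel ranks get distinct seeds so weight init and dropout
    # masks differ across model shards; data-parallel replicas keep the
    # original seed so they stay identical. The 2718 offset is illustrative.
    tensor_model_parallel_seed = seed + 2718 + tensor_model_parallel_rank
    data_parallel_seed = seed
    # The default CUDA RNG state follows the data-parallel seed; a tracked
    # side state (the _CUDA_RNG_STATE_TRACKER in the hunk above) would hold
    # the tensor-parallel seed for regions needing per-shard randomness.
    torch.cuda.manual_seed(data_parallel_seed)
    return tensor_model_parallel_seed, data_parallel_seed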