Commit 5c45db4a authored by Jared Casper, committed by Deepak Narayanan

Initial implementation of pipelined text generation

parent caa9dca5
megatron/mpu/__init__.py
@@ -34,7 +34,8 @@ from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
 from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
 from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
 from .initialize import get_tensor_model_parallel_src_rank
-from .initialize import get_pipeline_model_parallel_src_rank
+from .initialize import get_pipeline_model_parallel_first_rank
+from .initialize import get_pipeline_model_parallel_last_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
 from .initialize import initialize_model_parallel
megatron/mpu/initialize.py
@@ -38,6 +38,7 @@ _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_TENSOR_MODEL_PARALLEL_RANK = None
 _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+_PIPELINE_GLOBAL_RANKS = None


 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
@@ -131,6 +132,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     # Build the pipeline model-parallel groups and embedding groups
     # (first and last rank in each pipeline model-parallel group).
     global _PIPELINE_MODEL_PARALLEL_GROUP
+    global _PIPELINE_GLOBAL_RANKS
     assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
         'pipeline model parallel group is already initialized'
     global _EMBEDDING_GROUP
@@ -142,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
             _PIPELINE_MODEL_PARALLEL_GROUP = group
+            _PIPELINE_GLOBAL_RANKS = ranks
         # Setup embedding group (to exchange gradients between
         # first and last stages).
         if len(ranks) > 1:
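Note: _PIPELINE_GLOBAL_RANKS caches, on each rank, the list of global ranks that make up its own pipeline group. A small worked example follows; it assumes each pipeline group is built as a strided range over the global ranks (the construction of `ranks` sits in the collapsed context above), which is how Megatron lays out these groups.

# Worked example, not part of the diff.
world_size = 8
pipeline_model_parallel_size = 2
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size  # 4

groups = [list(range(i, world_size, num_pipeline_model_parallel_groups))
          for i in range(num_pipeline_model_parallel_groups)]
print(groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
# Global rank 5 therefore caches _PIPELINE_GLOBAL_RANKS = [1, 5]:
# its pipeline-first rank is 1 and its pipeline-last rank is 5.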
@@ -265,21 +268,22 @@ def is_pipeline_last_stage():
 def get_tensor_model_parallel_src_rank():
-    """Calculate the global rank corresponding to a local rank
+    """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
     global_rank = torch.distributed.get_rank()
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size


-def get_pipeline_model_parallel_src_rank():
-    """Calculate the global rank corresponding to a local rank
-    in the pipeline model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    global_world_size = torch.distributed.get_world_size()
-    local_world_size = get_pipeline_model_parallel_world_size()
-    return global_rank % (global_world_size // local_world_size)
+def get_pipeline_model_parallel_first_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    return _PIPELINE_GLOBAL_RANKS[0]
+
+
+def get_pipeline_model_parallel_last_rank():
+    assert _PIPELINE_GLOBAL_RANKS is not None, \
+        "Pipeline parallel group is not initialized"
+    last_rank_local = get_pipeline_model_parallel_world_size() - 1
+    return _PIPELINE_GLOBAL_RANKS[last_rank_local]


 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
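Note: unlike the removed get_pipeline_model_parallel_src_rank, which derived a rank arithmetically, the new helpers read the cached group membership directly, so callers can address the first or last pipeline stage without knowing how ranks are laid out. A minimal sketch of the kind of use pipelined generation needs; the function below is illustrative, not code from this commit.

# Illustrative sketch only.  Assumes mpu is initialized and new_token is a
# CUDA tensor of shape [batch_size].
import torch
from megatron import mpu

def sync_new_token(new_token):
    """Send the token sampled on the last pipeline stage back to the
    first stage, which embeds it for the next generation step."""
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return new_token  # single stage: nothing to communicate
    if mpu.is_pipeline_last_stage():
        torch.distributed.send(new_token,
                               mpu.get_pipeline_model_parallel_first_rank())
    elif mpu.is_pipeline_first_stage():
        torch.distributed.recv(new_token,
                               mpu.get_pipeline_model_parallel_last_rank())
    return new_token

Only the two end stages take part; intermediate stages never need the raw token, since they consume activations sent from upstream.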
This diff is collapsed.
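Note: the collapsed diff carries the bulk of the change, presumably the pipelined generation loop in megatron/text_generation_utils.py. As orientation only, here is a rough sketch of the per-step pattern such a loop follows; the helper name, call signatures, and the way neighbour ranks are obtained are assumptions, not this commit's code.

# Rough orientation sketch only; signatures and rank bookkeeping are approximate.
import torch
from megatron import mpu

def forward_step(model, tokens, position_ids, attention_mask,
                 prev_rank, next_rank, hidden_shape, dtype=torch.half):
    """Run one generation step on this pipeline stage.

    prev_rank / next_rank are the global ranks of the neighbouring stages;
    how the real code obtains them is left out here."""
    if mpu.is_pipeline_first_stage():
        # The first stage embeds the tokens itself.
        output = model(tokens, position_ids, attention_mask)
    else:
        # Later stages receive activations from the previous stage.
        hidden = torch.empty(hidden_shape, dtype=dtype, device='cuda')
        torch.distributed.recv(hidden, prev_rank)
        output = model(hidden, attention_mask)
    if mpu.is_pipeline_last_stage():
        return output  # logits; sample the next token from these
    # Everyone but the last stage forwards its activations downstream.
    torch.distributed.send(output, next_rank)
    return None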
tools/generate_samples_gpt2.py
@@ -23,9 +23,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
+from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from megatron.model import GPT2Model
+from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
 from megatron.training import get_model
 from megatron.text_generation_utils import generate_and_write_samples_unconditional
 from megatron.text_generation_utils import generate_samples_input_from_file
@@ -36,7 +37,19 @@ def model_provider():
     """Build the model."""

     print_rank_0('building GPT2 model ...')
-    model = GPT2Model(num_tokentypes=0, parallel_output=False)
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        # Determine model based on position of stage in pipeline.
+        if mpu.is_pipeline_first_stage():
+            model = GPT2ModelFirstStage(num_tokentypes=0)
+        elif mpu.is_pipeline_last_stage():
+            model = GPT2ModelLastStage(
+                num_tokentypes=0, parallel_output=False)
+        else:
+            model = GPT2ModelIntermediateStage(
+                num_tokentypes=0)
+    else:
+        model = GPT2Model(num_tokentypes=0, parallel_output=False)

     return model
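Note: model_provider now builds only the slice of the network each rank owns: the first stage holds the embedding, the last stage holds the output head (parallel_output=False so full logits are available for sampling), and every other stage holds a span of transformer layers. A hedged sketch of how the script's entry point typically wires this provider together follows; argument names such as sample_input_file are assumptions here, not taken from the diff.

# Hedged sketch (not from the diff) of a typical driver for model_provider above.
from megatron import get_args
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.text_generation_utils import (
    generate_and_write_samples_unconditional,
    generate_samples_input_from_file)
from megatron.training import get_model

def main():
    # Distributed/model-parallel setup, then build only this stage's module.
    initialize_megatron(args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
    args = get_args()
    model = get_model(model_provider)
    if args.load is not None:
        load_checkpoint(model, None, None)  # each stage restores its own shard
    model.eval()
    # Complete prompts read from a file, or sample unconditionally.
    if getattr(args, 'sample_input_file', None):
        generate_samples_input_from_file(model)
    else:
        generate_and_write_samples_unconditional(model)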