Commit edca71da authored by Jared Casper

Merge branch 'core-pipeline' into 'main'

Move pipeline parallel functionality into core with associated changes.

See merge request ADLR/megatron-lm!499
parents 0b44909c 3c92fa93
@@ -17,7 +17,8 @@ from megatron import print_rank_0
 from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import GPTModel, ModelType
+from megatron.model import GPTModel
+from megatron.core.enums import ModelType
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group
...
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import enum

class ModelType(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2
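For orientation, here is a hedged sketch of how the relocated enum is consumed by the pretraining entry points touched later in this diff. The provider and step functions are placeholders, not part of the merge request; only the import path and the `pretrain(...)` positional use of `ModelType` follow the existing scripts.

from megatron.core.enums import ModelType   # new import path after this change
from megatron.training import pretrain

def train_valid_test_datasets_provider(train_val_test_num_samples):
    ...  # placeholder: build train/valid/test datasets here

def model_provider(pre_process=True, post_process=True):
    ...  # placeholder: build and return the model here

def forward_step(data_iterator, model):
    ...  # placeholder: run one forward pass and return (output, loss_func)

if __name__ == "__main__":
    # GPT/BERT-style models pass encoder_or_decoder; T5-style models
    # would pass ModelType.encoder_and_decoder instead.
    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder, forward_step)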
@@ -58,12 +58,40 @@ def initialize_model_parallel(
     Initialize model data parallel groups.
 
     Arguments:
-        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
-        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
-        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
-                                              pipeline).
-        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
-                                            rank in pipeline with split point.
+        tensor_model_parallel_size (int, default = 1):
+            The number of GPUs to split individual tensors across.
+
+        pipeline_model_parallel_size (int, default = 1):
+            The number of tensor parallel GPU groups to split the
+            Transformer layers across. For example, if
+            tensor_model_parallel_size is 4 and
+            pipeline_model_parallel_size is 2, the model will be split
+            into 2 groups of 4 GPUs.
+
+        virtual_pipeline_model_parallel_size (int, optional):
+            The number of stages that each pipeline group will have,
+            interleaving as necessary. If None, no interleaving is
+            performed. For example, if tensor_model_parallel_size is 1,
+            pipeline_model_parallel_size is 4,
+            virtual_pipeline_model_parallel_size is 2, and there are
+            16 transformer layers in the model, the model will be
+            split into 8 stages with two layers each and each GPU
+            would get 2 stages as such (layer number starting with 1):
+
+            GPU 0: [1, 2] [9, 10]
+            GPU 1: [3, 4] [11, 12]
+            GPU 2: [5, 6] [13, 14]
+            GPU 3: [7, 8] [15, 16]
+
+        pipeline_model_parallel_split_rank (int, optional):
+            For models with both an encoder and decoder, the rank in
+            pipeline to switch between encoder and decoder (i.e. the
+            first rank of the decoder). This allows the user to set
+            the pipeline parallel size of the encoder and decoder
+            independently. For example, if
+            pipeline_model_parallel_size is 8 and
+            pipeline_model_parallel_split_rank is 3, then ranks 0-2
+            will be the encoder and ranks 3-7 will be the decoder.
 
     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
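A minimal usage sketch matching the documented arguments, assuming the job is launched with torchrun (which sets the usual RANK/WORLD_SIZE environment variables) on the 16-GPU example above; this is illustrative and not part of the diff.

import torch
from megatron.core import parallel_state

# torchrun provides the rendezvous info; NCCL backend for GPU training.
torch.distributed.init_process_group(backend="nccl")

parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=2,              # split each tensor across 2 GPUs
    pipeline_model_parallel_size=4,            # 4 pipeline stages
    virtual_pipeline_model_parallel_size=None, # set to 2 for the interleaved schedule above
    pipeline_model_parallel_split_rank=None,   # only needed for encoder+decoder models
)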
@@ -298,8 +326,8 @@ def set_pipeline_model_parallel_rank(rank):
 
 def set_pipeline_model_parallel_split_rank(rank):
     """Set pipeline model parallel split rank."""
-    global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
-    _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
 
 
 def get_tensor_model_parallel_rank():
@@ -318,6 +346,11 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
 
 
+def get_pipeline_model_parallel_split_rank():
+    """Return pipeline model parallel split rank."""
+    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
+
+
 def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
...
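A small hedged sketch of how the renamed global and the new getter pair up. Illustrative only: in real runs the split rank is normally supplied through initialize_model_parallel(pipeline_model_parallel_split_rank=...), and the rank query assumes model parallelism has already been initialized.

from megatron.core import parallel_state

# Suppose an encoder-decoder model runs with 8 pipeline stages, 3 of them encoder.
parallel_state.set_pipeline_model_parallel_split_rank(3)

split_rank = parallel_state.get_pipeline_model_parallel_split_rank()
if parallel_state.get_pipeline_model_parallel_rank() < split_rank:
    print("this rank holds an encoder stage")
else:
    print("this rank holds a decoder stage")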
from .schedules import get_forward_backward_func
@@ -13,6 +13,8 @@ import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter
 
+from torch.cuda.amp import custom_fwd, custom_bwd
+
 from megatron.core.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -214,6 +216,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     """See linear_with_grad_accumulation_and_async_allreduce"""
 
     @staticmethod
+    @custom_fwd
    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                async_grad_allreduce, sequence_parallel):
        ctx.save_for_backward(input, weight)
@@ -243,6 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         return output
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
@@ -407,8 +411,8 @@ def linear_with_grad_accumulation_and_async_allreduce(
                 "maximum speedup")
             linear_with_grad_accumulation_and_async_allreduce.warned = True
 
-    return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
+    with torch.cuda.amp.autocast(enabled=False):
+        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
 
 linear_with_grad_accumulation_and_async_allreduce.warned = False
 
 
 class ColumnParallelLinear(torch.nn.Module):
...
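A standalone sketch of the amp pattern adopted above (toy code, not Megatron's): @custom_fwd records the autocast state active during forward, @custom_bwd replays the same state in backward so both passes see consistent dtypes, and the call site disables autocast so the Function operates on inputs that have already been cast by the caller.

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        # Forward sees x as-is; the autocast state here is remembered for backward.
        return x * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # Runs under the same autocast state as forward.
        return grad_output * 2

def scale_by_two(x):
    # Disable autocast at the call site, mirroring
    # linear_with_grad_accumulation_and_async_allreduce above.
    with torch.cuda.amp.autocast(enabled=False):
        return ScaleByTwo.apply(x)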
@@ -20,6 +20,21 @@ def divide(numerator, denominator):
     ensure_divisibility(numerator, denominator)
     return numerator // denominator
 
+
+def get_attr_wrapped_model(model, attr):
+    """Get an attribute from a wrapped model"""
+    if isinstance(model, list):
+        raise RuntimeError("_get_attr_wrapped_model given a list of models")
+    while not hasattr(model, attr):
+        if not hasattr(model, "module"):
+            raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
+        model = model.module
+    return getattr(model, attr)
+
+
+def get_model_type(model):
+    return get_attr_wrapped_model(model, 'model_type')
+
 class GlobalMemoryBuffer:
     """Global buffer to avoid dynamic memory allocations.
...
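An illustrative sketch of the new helpers, assuming they land in megatron.core.utils as the surrounding hunk suggests: get_attr_wrapped_model walks through ".module" wrappers (DistributedDataParallel, Float16Module, and the like) until it finds the requested attribute on the underlying model. The toy classes below stand in for a wrapped Megatron model.

import torch
from megatron.core.utils import get_attr_wrapped_model, get_model_type

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model_type = "encoder_or_decoder"  # placeholder attribute

class Wrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

wrapped = Wrapper(Wrapper(ToyModel()))
print(get_attr_wrapped_model(wrapped, "model_type"))  # -> "encoder_or_decoder"
print(get_model_type(wrapped))                        # same lookup via the shortcut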
@@ -8,4 +8,3 @@ from .gpt_model import GPTModel
 from .t5_model import T5Model
 from .language_model import get_language_model
 from .module import Float16Module
-from .enums import ModelType
@@ -2,10 +2,6 @@
 
 import enum
 
-class ModelType(enum.Enum):
-    encoder_or_decoder = 1
-    encoder_and_decoder = 2
-
 class LayerType(enum.Enum):
     encoder = 1
     decoder = 2
...
@@ -20,7 +20,8 @@ from megatron import get_args, get_retro_args, get_tensorboard_writer
 from megatron.core import parallel_state
 from megatron.core import tensor_parallel
 from megatron.core import utils as core_utils
-from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
+from megatron.core.enums import ModelType
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
...
@@ -9,7 +9,8 @@ import torch.nn.functional as F
 from megatron import get_timers, get_args, core, get_num_microbatches
 from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
-from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
+from megatron.core.enums import ModelType
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
...
@@ -25,8 +25,8 @@ from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.model import Float16Module
-from megatron.model import ModelType
 from megatron.model import GPTModel
+from megatron.core.enums import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
@@ -37,7 +37,7 @@ from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
-from megatron.schedules import get_forward_backward_func
+from megatron.core.pipeline_parallel import get_forward_backward_func
 from megatron.utils import report_memory
 from megatron.model.vision.knn_monitor import compute_feature_bank
@@ -400,6 +400,7 @@ def setup_model_and_optimizer(model_provider_func,
 
     return model, optimizer, opt_param_scheduler
 
+
 def train_step(forward_step_func, data_iterator,
                model, optimizer, opt_param_scheduler):
     """Single training step."""
@@ -418,8 +419,16 @@ def train_step(forward_step_func, data_iterator,
     forward_backward_func = get_forward_backward_func()
     fwd_bwd_timers = timers if args.timing_log_level > 1 else None
     losses_reduced = forward_backward_func(
-        forward_step_func, data_iterator, model,
-        optimizer, fwd_bwd_timers, forward_only=False)
+        forward_step_func=forward_step_func,
+        data_iterator=data_iterator,
+        model=model,
+        num_microbatches=get_num_microbatches(),
+        dtype=args.params_dtype,
+        tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
+        grad_scaler=optimizer.scale_loss,
+        sequence_parallel=args.sequence_parallel,
+        forward_only=False,
+        timers=fwd_bwd_timers)
     timers('forward-backward').stop()
 
     # Empty unused memory.
@@ -794,8 +803,15 @@ def evaluate(forward_step_func,
             forward_backward_func = get_forward_backward_func()
             loss_dicts = forward_backward_func(
-                forward_step_func, data_iterator, model, optimizer=None,
-                timers=None, forward_only=True)
+                forward_step_func=forward_step_func,
+                data_iterator=data_iterator,
+                model=model,
+                num_microbatches=get_num_microbatches(),
+                dtype=args.params_dtype,
+                tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
+                sequence_parallel=args.sequence_parallel,
+                forward_only=True,
+                timers=None)
 
             # Empty unused memory
             if args.empty_unused_memory_level >= 1:
...
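A hedged sketch of the new keyword-based calling convention seen in train_step and evaluate above. The helper name and its parameters are illustrative; the forward_backward_func keyword arguments mirror the diff, and grad_scaler is simply omitted (left None) for forward-only evaluation.

from megatron.core.pipeline_parallel import get_forward_backward_func

def run_forward_backward(forward_step_func, data_iterator, model, args,
                         num_microbatches, grad_scaler=None, forward_only=False):
    # Pick the schedule appropriate for the current pipeline configuration
    # (no pipelining, 1F1B, or interleaved 1F1B).
    forward_backward_func = get_forward_backward_func()
    return forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=data_iterator,
        model=model,
        num_microbatches=num_microbatches,
        dtype=args.params_dtype,
        tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
        grad_scaler=grad_scaler,            # e.g. optimizer.scale_loss during training
        sequence_parallel=args.sequence_parallel,
        forward_only=forward_only,
        timers=None)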
@@ -11,8 +11,9 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron.core import tensor_parallel
+from megatron.core.enums import ModelType
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import BertModel, ModelType
+from megatron.model import BertModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
...
@@ -9,8 +9,9 @@ from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
 from megatron.core import tensor_parallel
+from megatron.core.enums import ModelType
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import GPTModel, ModelType
+from megatron.model import GPTModel
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group
...
@@ -13,9 +13,9 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron.core import mpu
+from megatron.core.enums import ModelType
 from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import ModelType
 from megatron.model.biencoder_model import biencoder_model_provider
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
...
@@ -10,7 +10,8 @@ from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import print_rank_0
 from megatron.core import mpu, tensor_parallel
-from megatron.model import GPTModel, ModelType
+from megatron.core.enums import ModelType
+from megatron.model import GPTModel
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from tools.retro.pretraining.retro_dataset import get_retro_datasets
...
@@ -12,8 +12,9 @@ from megatron import (
     print_rank_0
 )
 from megatron.core import tensor_parallel
+from megatron.core.enums import ModelType
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import T5Model, ModelType
+from megatron.model import T5Model
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
...
@@ -6,8 +6,8 @@ import torch
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers, print_rank_0
+from megatron.core.enums import ModelType
 from megatron.data.vit_dataset import build_train_valid_datasets
-from megatron.model import ModelType
 from megatron.model.vision.classification import VitClassificationModel
 from megatron.model.vision.classification import MitClassificationModel
 from megatron.training import pretrain
...
@@ -7,6 +7,7 @@ import numpy as np
 import torch.distributed as dist
 from functools import partial
 from megatron import get_args, get_timers, print_rank_0
+from megatron.core.enums import ModelType
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.dino import DINOPretrainModel
 from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
@@ -15,7 +16,6 @@ from megatron.utils import average_losses_across_data_parallel_group, unwrap_mod
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
-from megatron.model import ModelType
 
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
...