Commit edca71da authored by Jared Casper

Merge branch 'core-pipeline' into 'main'

Move pipeline parallel functionality into core with associated changes.

See merge request ADLR/megatron-lm!499
parents 0b44909c 3c92fa93
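For downstream code, the visible effect of this merge is a set of import moves: ModelType now lives in megatron.core.enums instead of megatron.model, and the pipeline schedules are reached through megatron.core.pipeline_parallel instead of megatron.schedules. A minimal before/after sketch, taken directly from the hunks below:

# Before this merge
from megatron.model import GPTModel, ModelType
from megatron.schedules import get_forward_backward_func

# After this merge
from megatron.model import GPTModel
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import get_forward_backward_func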
......@@ -17,7 +17,8 @@ from megatron import print_rank_0
from megatron.core import mpu
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel, ModelType
from megatron.model import GPTModel
from megatron.core.enums import ModelType
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
......@@ -58,12 +58,40 @@ def initialize_model_parallel(
Initialize model data parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
pipeline).
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
rank in pipeline with split point.
tensor_model_parallel_size (int, default = 1):
The number of GPUs to split individual tensors across.
pipeline_model_parallel_size (int, default = 1):
The number of tensor parallel GPU groups to split the
Transformer layers across. For example, if
tensor_model_parallel_size is 4 and
pipeline_model_parallel_size is 2, the model will be split
into 2 groups of 4 GPUs.
virtual_pipeline_model_parallel_size (int, optional):
The number of stages that each pipeline group will have,
interleaving as necessary. If None, no interleaving is
performed. For example, if tensor_model_parallel_size is 1,
pipeline_model_parallel_size is 4,
virtual_pipeline_model_parallel_size is 2, and there are
16 transformer layers in the model, the model will be
split into 8 stages with two layers each and each GPU
would get 2 stages as such (layer number starting with 1):
GPU 0: [1, 2] [9, 10]
GPU 1: [3, 4] [11, 12]
GPU 2: [5, 6] [13, 14]
GPU 3: [7, 8] [15, 16]
pipeline_model_parallel_split_rank (int, optional):
For models with both an encoder and decoder, the rank in
pipeline to switch between encoder and decoder (i.e. the
first rank of the decoder). This allows the user to set
the pipeline parallel size of the encoder and decoder
independently. For example, if
pipeline_model_parallel_size is 8 and
pipeline_model_parallel_split_rank is 3, then ranks 0-2
will be the encoder and ranks 3-7 will be the decoder.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
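A minimal sketch of how the documented arguments fit together, assuming the script is launched with torchrun on 4 GPUs so that torch.distributed picks up the rank and world size from the environment; everything here other than initialize_model_parallel and the getters defined in this file is illustrative:

import torch
from megatron.core import parallel_state

torch.distributed.init_process_group(backend="nccl")  # rank/world size supplied by torchrun
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=4,
    virtual_pipeline_model_parallel_size=2,   # interleaved schedule: each rank holds 2 stages
    pipeline_model_parallel_split_rank=None,  # only set for encoder-and-decoder models
)
print(parallel_state.get_pipeline_model_parallel_rank(),
      parallel_state.get_pipeline_model_parallel_split_rank())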
......@@ -298,8 +326,8 @@ def set_pipeline_model_parallel_rank(rank):
def set_pipeline_model_parallel_split_rank(rank):
"""Set pipeline model parallel split rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
def get_tensor_model_parallel_rank():
......@@ -318,6 +346,11 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def get_pipeline_model_parallel_split_rank():
"""Return pipeline model parallel split rank."""
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
......
from .schedules import get_forward_backward_func
......@@ -13,6 +13,8 @@ import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from torch.cuda.amp import custom_fwd, custom_bwd
from megatron.core.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
......@@ -214,6 +216,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
"""See linear_with_grad_accumulation_and_async_allreduce"""
@staticmethod
@custom_fwd
def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
async_grad_allreduce, sequence_parallel):
ctx.save_for_backward(input, weight)
......@@ -243,6 +246,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
......@@ -407,8 +411,8 @@ def linear_with_grad_accumulation_and_async_allreduce(
"maximum speedup")
linear_with_grad_accumulation_and_async_allreduce.warned = True
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
linear_with_grad_accumulation_and_async_allreduce.warned = False
class ColumnParallelLinear(torch.nn.Module):
......
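The hunks above decorate the autograd function's forward and backward with torch.cuda.amp.custom_fwd/custom_bwd and drop the explicit autocast(enabled=False) guard around .apply(). A toy illustration of that decorator pattern on a CUDA machine; the ScaleByTwo function is invented for illustration and is not part of Megatron:

import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        return x * 2              # runs under the caller's autocast state

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        return grad_output * 2    # custom_bwd restores the autocast state seen in forward

x = torch.randn(4, device="cuda", requires_grad=True)
with torch.cuda.amp.autocast():
    y = ScaleByTwo.apply(x)       # no manual autocast(enabled=False) wrapper needed
y.float().sum().backward()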
......@@ -20,6 +20,21 @@ def divide(numerator, denominator):
ensure_divisibility(numerator, denominator)
return numerator // denominator
def get_attr_wrapped_model(model, attr):
"""Get an attribute from a wrapped model"""
if isinstance(model, list):
raise RuntimeError("_get_attr_wrapped_model given a list of models")
while not hasattr(model, attr):
if not hasattr(model, "module"):
raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")
model = model.module
return getattr(model, attr)
def get_model_type(model):
return get_attr_wrapped_model(model, 'model_type')
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
......
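get_attr_wrapped_model simply walks .module attributes until it finds the requested attribute, so a toy check of the new helpers looks like this; it assumes they are exported from megatron.core.utils (the file this hunk edits), and ToyModel/Wrapper are invented stand-ins for a model inside Float16Module/DDP-style containers:

import torch
from megatron.core.enums import ModelType
from megatron.core.utils import get_attr_wrapped_model, get_model_type

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model_type = ModelType.encoder_or_decoder

class Wrapper(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module       # mimics how Megatron wrappers expose the inner model

wrapped = Wrapper(Wrapper(ToyModel()))
assert get_attr_wrapped_model(wrapped, "model_type") is ModelType.encoder_or_decoder
assert get_model_type(wrapped) is ModelType.encoder_or_decoder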
......@@ -8,4 +8,3 @@ from .gpt_model import GPTModel
from .t5_model import T5Model
from .language_model import get_language_model
from .module import Float16Module
from .enums import ModelType
......@@ -2,10 +2,6 @@
import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
class LayerType(enum.Enum):
encoder = 1
decoder = 2
......
......@@ -20,7 +20,8 @@ from megatron import get_args, get_retro_args, get_tensorboard_writer
from megatron.core import parallel_state
from megatron.core import tensor_parallel
from megatron.core import utils as core_utils
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
......
......@@ -9,7 +9,8 @@ import torch.nn.functional as F
from megatron import get_timers, get_args, core, get_num_microbatches
from .module import MegatronModule
from megatron.core import mpu, tensor_parallel
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
......
......@@ -25,8 +25,8 @@ from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.model import GPTModel
from megatron.core.enums import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
......@@ -37,7 +37,7 @@ from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import get_forward_backward_func
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.utils import report_memory
from megatron.model.vision.knn_monitor import compute_feature_bank
......@@ -400,6 +400,7 @@ def setup_model_and_optimizer(model_provider_func,
return model, optimizer, opt_param_scheduler
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler):
"""Single training step."""
......@@ -418,8 +419,16 @@ def train_step(forward_step_func, data_iterator,
forward_backward_func = get_forward_backward_func()
fwd_bwd_timers = timers if args.timing_log_level > 1 else None
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, fwd_bwd_timers, forward_only=False)
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
dtype=args.params_dtype,
tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
grad_scaler=optimizer.scale_loss,
sequence_parallel=args.sequence_parallel,
forward_only=False,
timers=fwd_bwd_timers)
timers('forward-backward').stop()
# Empty unused memory.
......@@ -794,8 +803,15 @@ def evaluate(forward_step_func,
forward_backward_func = get_forward_backward_func()
loss_dicts = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True)
forward_step_func=forward_step_func,
data_iterator=data_iterator,
model=model,
num_microbatches=get_num_microbatches(),
dtype=args.params_dtype,
tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
sequence_parallel=args.sequence_parallel,
forward_only=True,
timers=None)
# Empty unused memory
if args.empty_unused_memory_level >= 1:
......
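Both call sites now spell out the schedule's inputs as keyword arguments instead of relying on positional order. A hedged sketch of that calling convention, mirroring the updated train_step/evaluate code above; the run_step helper and its parameters are illustrative, and passing grad_scaler=None for evaluation assumes the schedule treats that as the default, matching the evaluate() call above which simply omits it:

from megatron import get_num_microbatches
from megatron.core.pipeline_parallel import get_forward_backward_func

def run_step(forward_step_func, data_iterator, model, args, optimizer=None, timers=None):
    """Run one forward (and optionally backward) pass over all microbatches."""
    forward_backward_func = get_forward_backward_func()
    training = optimizer is not None
    return forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=data_iterator,
        model=model,
        num_microbatches=get_num_microbatches(),
        dtype=args.params_dtype,
        tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
        grad_scaler=optimizer.scale_loss if training else None,
        sequence_parallel=args.sequence_parallel,
        forward_only=not training,
        timers=timers)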
......@@ -11,8 +11,9 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel, ModelType
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
......
......@@ -9,8 +9,9 @@ from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel, ModelType
from megatron.model import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
......
......@@ -13,9 +13,9 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ModelType
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
......
......@@ -10,7 +10,8 @@ from megatron import get_timers
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.model import GPTModel, ModelType
from megatron.core.enums import ModelType
from megatron.model import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from tools.retro.pretraining.retro_dataset import get_retro_datasets
......
......@@ -12,8 +12,9 @@ from megatron import (
print_rank_0
)
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model, ModelType
from megatron.model import T5Model
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
......
......@@ -6,8 +6,8 @@ import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, print_rank_0
from megatron.core.enums import ModelType
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vision.classification import VitClassificationModel
from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain
......
......@@ -7,6 +7,7 @@ import numpy as np
import torch.distributed as dist
from functools import partial
from megatron import get_args, get_timers, print_rank_0
from megatron.core.enums import ModelType
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.dino import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
......@@ -15,7 +16,6 @@ from megatron.utils import average_losses_across_data_parallel_group, unwrap_mod
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
......