Commit 3c92fa93 authored by Jared Casper

Move pipeline parallel functionality into core with associated changes.

parent 0b44909c
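The recurring change across the diffs below is the new home of the ModelType enum: it moves from megatron.model into megatron.core.enums, so every caller updates its import. A minimal sketch of the migration for downstream code (illustrative only, not part of this diff):

# Before this commit:
#   from megatron.model import ModelType
# After this commit:
from megatron.core.enums import ModelType

# Usage is unchanged, e.g. the model type passed to megatron.training.pretrain:
model_type = ModelType.encoder_or_decoder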
@@ -6,13 +6,13 @@ import torch
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers, print_rank_0, print_rank_last
+from megatron.core.enums import ModelType
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.inpainting import VitInpaintingModel
 from megatron.model.vision.inpainting import MitInpaintingModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 from tasks.vision.metrics import SSIM, PSNR
-from megatron.model import ModelType

 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
...
@@ -10,9 +10,9 @@ from megatron import get_args, get_num_microbatches
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron.core import mpu
+from megatron.core.enums import ModelType
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
-from megatron.model import ModelType
 from megatron.training import evaluate_and_print_results
 from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
...
@@ -19,8 +19,8 @@ from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model import Float16Module, ModelType
+from megatron.model import Float16Module
+from megatron.core.enums import ModelType

 def process_batch(batch):
     """Process batch and produce inputs for the model."""
...
import torch
from tests.test_utilities import Utils
import megatron.core.pipeline_parallel.schedules as schedule
from pytest_mock import mocker
import pytest

rank = Utils.rank


def test_get_forward_backward_func():
    # The schedule returned depends only on the pipeline-parallel configuration.
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
    assert(schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining)
    Utils.destroy_model_parallel()
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_without_interleaving)
    Utils.destroy_model_parallel()
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4, virtual_pipeline_model_parallel_size=2)
    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_with_interleaving)
    Utils.destroy_model_parallel()


def test_deallocate_output_tensor():
    # deallocate_output_tensor frees the tensor's storage, leaving a one-element placeholder.
    out = torch.tensor([[1, 2, 3], [4, 5, 6]])
    schedule.deallocate_output_tensor(out)
    assert(out.nelement() == 1)


def test_forward_backward_func_without_pipeline_parallel(mocker):
    from megatron.core.pipeline_parallel import get_forward_backward_func
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)

    def forward_step_func(data_iterator, model):
        import os
        rank = int(os.environ['LOCAL_RANK'])
        dummy_data = torch.ones(1, 4)

        def loss_func(output_tensor):
            return rank, {'loss_reduced': rank}
        return model(dummy_data), loss_func

    model = torch.nn.Linear(4, 1)
    model.model_type = 'unit-test'

    def set_input_tensor(input_tensor):
        return None
    model.set_input_tensor = set_input_tensor

    forward_backward_func = get_forward_backward_func()
    assert(schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining)

    # Mock out the real backward pass; this test only checks the schedule's loss bookkeeping.
    mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2)
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=None,
        model=[model],
        num_microbatches=4,
        forward_only=False)
    loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
    for i, j in zip(losses_reduced, loss_reduced_expected):
        print(losses_reduced)
        assert(i['loss_reduced'] == j['loss_reduced'])
    Utils.destroy_model_parallel()


def test_forward_backward_func_with_pipeline_parallel(mocker):
    from megatron.core.pipeline_parallel import get_forward_backward_func
    Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=4)

    def forward_step_func(data_iterator, model):
        import os
        rank = int(os.environ['LOCAL_RANK'])

        def loss_func(output_tensor):
            return rank, {'loss_reduced': rank}
        return torch.rand(512, 8, 256).cuda(), loss_func

    model = torch.nn.Linear(4, 1)
    model.model_type = 'unit-test'

    def set_input_tensor(input_tensor):
        return None
    model.set_input_tensor = set_input_tensor

    forward_backward_func = get_forward_backward_func()
    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_without_interleaving)

    sequence_length = 512
    micro_batch_size = 8
    hidden_size = 256

    # forward_only=True, so no backward pass is needed here.
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=None,
        dtype=torch.float32,
        model=[model],
        num_microbatches=micro_batch_size,
        tensor_shape=[sequence_length, micro_batch_size, hidden_size],
        decoder_seq_length=sequence_length,
        sequence_parallel=False,
        forward_only=True)
    loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
    for i, j in zip(losses_reduced, loss_reduced_expected):
        print(losses_reduced)
        assert(i['loss_reduced'] == j['loss_reduced'])
    Utils.destroy_model_parallel()


# The interleaved-schedule test below is disabled: its whole body is wrapped
# in a module-level string literal.
"""
def test_forward_backward_func_with_interleaving(mocker):
    from megatron.core.pipeline_parallel import get_forward_backward_func
    from megatron.core.enums import ModelType
    Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=4, virtual_pipeline_model_parallel_size=2)

    def forward_step_func(data_iterator, model):
        import os
        rank = int(os.environ['LOCAL_RANK'])

        def loss_func(output_tensor):
            return rank, {'loss_reduced': rank}
        return torch.rand(512, 8, 256).cuda(), loss_func

    model = torch.nn.Linear(4, 1)

    def set_input_tensor(input_tensor):
        return None
    model.set_input_tensor = set_input_tensor

    forward_backward_func = get_forward_backward_func()
    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_with_interleaving)

    sequence_length = 512
    micro_batch_size = 8
    hidden_size = 256

    mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2)

    # An encoder-and-decoder model is rejected by the interleaved schedule.
    with pytest.raises(RuntimeError):
        model.model_type = ModelType.encoder_and_decoder
        forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=range(0, 100),
            dtype=torch.float32,
            model=[model, model],
            num_microbatches=micro_batch_size,
            tensor_shape=[sequence_length, micro_batch_size, hidden_size],
            decoder_seq_length=sequence_length,
            sequence_parallel=False,
            forward_only=True)

    # A decoder sequence length that differs from the tensor shape is rejected.
    with pytest.raises(RuntimeError):
        model.model_type = ModelType.encoder_or_decoder
        forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=range(0, 100),
            dtype=torch.float32,
            model=[model, model],
            num_microbatches=micro_batch_size,
            tensor_shape=[sequence_length, micro_batch_size, hidden_size],
            decoder_seq_length=256,
            sequence_parallel=False,
            forward_only=True)

    # A microbatch count that is not a multiple of the pipeline size is rejected.
    with pytest.raises(RuntimeError):
        model.model_type = ModelType.encoder_or_decoder
        forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=range(0, 100),
            dtype=torch.float32,
            model=[model, model],
            num_microbatches=7,
            tensor_shape=[sequence_length, micro_batch_size, hidden_size],
            decoder_seq_length=512,
            sequence_parallel=False,
            forward_only=True)

    model.model_type = ModelType.encoder_or_decoder
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=range(0, 100),
        dtype=torch.float32,
        model=[model, model],
        num_microbatches=micro_batch_size,
        tensor_shape=[sequence_length, micro_batch_size, hidden_size],
        decoder_seq_length=sequence_length,
        sequence_parallel=True,
        forward_only=True)
    loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
    for i, j in zip(losses_reduced, loss_reduced_expected):
        print(losses_reduced)
        assert(i['loss_reduced'] == j['loss_reduced'])
    Utils.destroy_model_parallel()
"""
\ No newline at end of file
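The dispatch these tests assert can be summarized as follows. This is a sketch based on the parallel-state queries in megatron.core, not a copy of the committed implementation:

from megatron.core import parallel_state
import megatron.core.pipeline_parallel.schedules as schedule

def pick_schedule():
    # One pipeline stage: run microbatches back to back, no inter-stage communication.
    if parallel_state.get_pipeline_model_parallel_world_size() <= 1:
        return schedule.forward_backward_no_pipelining
    # Virtual pipeline stages configured: interleaved 1F1B over model chunks.
    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
        return schedule.forward_backward_pipelining_with_interleaving
    # Multiple stages, no virtual stages: plain non-interleaved 1F1B.
    return schedule.forward_backward_pipelining_without_interleaving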
@@ -11,7 +11,8 @@ from tqdm import tqdm
 from megatron import get_args, get_tokenizer, print_rank_0
 from megatron import core
-from megatron.model import BertModel, ModelType
+from megatron.core.enums import ModelType
+from megatron.model import BertModel
 from megatron.schedules import get_forward_backward_func
 from megatron.training import setup_model_and_optimizer
...
@@ -29,8 +29,9 @@ def _load_checkpoint(queue, args):
         from megatron.arguments import parse_args, validate_args
         from megatron.global_vars import set_args, set_global_variables
         from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
-        from megatron.model import ModelType, module
+        from megatron.model import module
         from megatron.core import mpu
+        from megatron.core.enums import ModelType
         from megatron import fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
...
@@ -32,7 +32,7 @@ def save_checkpoint(queue, args):
         from megatron.arguments import (parse_args, validate_args)
         from megatron.checkpointing import save_checkpoint
         from megatron.global_vars import set_global_variables, get_args
-        from megatron.model import ModelType
+        from megatron.core.enums import ModelType
         from megatron.tokenizer.tokenizer import _vocab_size_with_padding
         from megatron import fused_kernels
         from megatron.core import mpu
...