Commit 3c92fa93 authored by Jared Casper

Move pipeline parallel functionality into core with associated changes.

parent 0b44909c
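For downstream callers, the visible change in the diffs below is an import move: ModelType now comes from megatron.core.enums rather than megatron.model, and the pipeline schedules are exposed from megatron.core.pipeline_parallel (exercised by the new tests added in this commit). A minimal before/after sketch, illustrative only and not taken verbatim from any single file:

    # old import location
    from megatron.model import ModelType

    # new import locations
    from megatron.core.enums import ModelType
    from megatron.core.pipeline_parallel import get_forward_backward_func
    import megatron.core.pipeline_parallel.schedules as schedule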
@@ -6,13 +6,13 @@ import torch
 import torch.nn.functional as F
 from functools import partial
 from megatron import get_args, get_timers, print_rank_0, print_rank_last
+from megatron.core.enums import ModelType
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vision.inpainting import VitInpaintingModel
 from megatron.model.vision.inpainting import MitInpaintingModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 from tasks.vision.metrics import SSIM, PSNR
-from megatron.model import ModelType

 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
...
@@ -10,9 +10,9 @@ from megatron import get_args, get_num_microbatches
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron.core import mpu
+from megatron.core.enums import ModelType
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
-from megatron.model import ModelType
 from megatron.training import evaluate_and_print_results
 from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
...
@@ -19,8 +19,8 @@ from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model import Float16Module, ModelType
+from megatron.model import Float16Module
+from megatron.core.enums import ModelType

 def process_batch(batch):
     """Process batch and produce inputs for the model."""
...
import torch
from tests.test_utilities import Utils
import megatron.core.pipeline_parallel.schedules as schedule
from pytest_mock import mocker
import pytest
rank = Utils.rank
def test_get_forward_backward_func():
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
    assert(schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining)
    Utils.destroy_model_parallel()
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_without_interleaving)
    Utils.destroy_model_parallel()
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4, virtual_pipeline_model_parallel_size=2)
    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_with_interleaving)
    Utils.destroy_model_parallel()
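# The three assertions above follow from the schedule selection rule in
# megatron.core.pipeline_parallel.schedules, which (roughly sketched here, not
# copied verbatim) picks the schedule from the current parallel state:
#
#     if pipeline_model_parallel_world_size > 1:
#         if virtual_pipeline_model_parallel_world_size is not None:
#             return forward_backward_pipelining_with_interleaving
#         return forward_backward_pipelining_without_interleaving
#     return forward_backward_no_pipelining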
def test_deallocate_output_tensor():
    out = torch.tensor([[1, 2, 3], [4, 5, 6]])
    schedule.deallocate_output_tensor(out)
    assert(out.nelement() == 1)
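# Assumption behind the assert above: deallocate_output_tensor releases the
# activation memory by swapping the tensor's storage for a one-element
# placeholder, roughly equivalent to
#     out.data = torch.empty((1,), device=out.device, dtype=out.dtype)
# so only a single element remains after the call.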
def test_forward_backward_func_without_pipeline_parallel(mocker):
    from megatron.core.pipeline_parallel import get_forward_backward_func
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)

    def forward_step_func(data_iterator, model):
        import os
        rank = int(os.environ['LOCAL_RANK'])
        dummy_data = torch.ones(1, 4)
        def loss_func(output_tensor):
            return rank, {'loss_reduced': rank}
        return model(dummy_data), loss_func

    model = torch.nn.Linear(4, 1)
    model.model_type = 'unit-test'
    def set_input_tensor(input_tensor):
        return None
    model.set_input_tensor = set_input_tensor

    forward_backward_func = get_forward_backward_func()
    assert(schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining)

    mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2)
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=None,
        model=[model],
        num_microbatches=4,
        forward_only=False)
    loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
    for i, j in zip(losses_reduced, loss_reduced_expected):
        print(losses_reduced)
        assert(i['loss_reduced'] == j['loss_reduced'])
    Utils.destroy_model_parallel()
def test_forward_backward_func_with_pipeline_parallel(mocker):
    from megatron.core.pipeline_parallel import get_forward_backward_func
    Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=4)

    def forward_step_func(data_iterator, model):
        import os
        rank = int(os.environ['LOCAL_RANK'])
        def loss_func(output_tensor):
            return rank, {'loss_reduced': rank}
        return torch.rand(512, 8, 256).cuda(), loss_func

    model = torch.nn.Linear(4, 1)
    model.model_type = 'unit-test'
    def set_input_tensor(input_tensor):
        return None
    model.set_input_tensor = set_input_tensor

    forward_backward_func = get_forward_backward_func()
    assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_without_interleaving)

    sequence_length = 512
    micro_batch_size = 8
    hidden_size = 256

    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=None,
        dtype=torch.float32,
        model=[model],
        num_microbatches=micro_batch_size,
        tensor_shape=[sequence_length, micro_batch_size, hidden_size],
        decoder_seq_length=sequence_length,
        sequence_parallel=False,
        forward_only=True)
    loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
    for i, j in zip(losses_reduced, loss_reduced_expected):
        print(losses_reduced)
        assert(i['loss_reduced'] == j['loss_reduced'])
    Utils.destroy_model_parallel()
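# Compared with the no-pipelining call above, the pipeline-parallel schedule is
# driven with explicit dtype, tensor_shape, decoder_seq_length and
# sequence_parallel arguments, since each stage needs to know the shape and
# dtype of the activations it exchanges with neighboring pipeline ranks.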
"""
def test_forward_backward_func_with_interleaving(mocker):
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.enums import ModelType
Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=4, virtual_pipeline_model_parallel_size=2)
def forward_step_func(data_iterator, model):
import os
rank = int(os.environ['LOCAL_RANK'])
def loss_func(output_tensor):
return rank, {'loss_reduced':rank}
return torch.rand(512,8,256).cuda(), loss_func
model = torch.nn.Linear(4,1)
def set_input_tensor(input_tensor):
return None
model.set_input_tensor = set_input_tensor
forward_backward_func = get_forward_backward_func()
assert(schedule.get_forward_backward_func() == schedule.forward_backward_pipelining_with_interleaving)
sequence_length = 512
micro_batch_size = 8
hidden_size = 256
mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2)
with pytest.raises(RuntimeError):
model.model_type = ModelType.encoder_and_decoder
forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=range(0,100),
dtype=torch.float32,
model=[model, model],
num_microbatches= micro_batch_size,
tensor_shape=[sequence_length, micro_batch_size, hidden_size],
decoder_seq_length=sequence_length,
sequence_parallel=False,
forward_only=True)
with pytest.raises(RuntimeError):
model.model_type = ModelType.encoder_or_decoder
forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=range(0,100),
dtype=torch.float32,
model=[model, model],
num_microbatches= micro_batch_size,
tensor_shape=[sequence_length, micro_batch_size, hidden_size],
decoder_seq_length=256,
sequence_parallel=False,
forward_only=True)
with pytest.raises(RuntimeError):
model.model_type = ModelType.encoder_or_decoder
forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=range(0,100),
dtype=torch.float32,
model=[model, model],
num_microbatches= 7,
tensor_shape=[sequence_length, micro_batch_size, hidden_size],
decoder_seq_length=512,
sequence_parallel=False,
forward_only=True)
model.model_type = ModelType.encoder_or_decoder
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=range(0,100),
dtype=torch.float32,
model=[model, model],
num_microbatches= micro_batch_size,
tensor_shape=[sequence_length, micro_batch_size, hidden_size],
decoder_seq_length=sequence_length,
sequence_parallel=True,
forward_only=True)
loss_reduced_expected = [{'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}, {'loss_reduced': rank}]
for i,j in zip(losses_reduced, loss_reduced_expected):
print(losses_reduced)
assert(i['loss_reduced'] == j['loss_reduced'])
Utils.destroy_model_parallel()
"""
\ No newline at end of file
@@ -11,7 +11,8 @@ from tqdm import tqdm
 from megatron import get_args, get_tokenizer, print_rank_0
 from megatron import core
-from megatron.model import BertModel, ModelType
+from megatron.core.enums import ModelType
+from megatron.model import BertModel
 from megatron.schedules import get_forward_backward_func
 from megatron.training import setup_model_and_optimizer
...
@@ -29,8 +29,9 @@ def _load_checkpoint(queue, args):
         from megatron.arguments import parse_args, validate_args
         from megatron.global_vars import set_args, set_global_variables
         from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
-        from megatron.model import ModelType, module
+        from megatron.model import module
         from megatron.core import mpu
+        from megatron.core.enums import ModelType
         from megatron import fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
...
@@ -32,7 +32,7 @@ def save_checkpoint(queue, args):
         from megatron.arguments import (parse_args, validate_args)
         from megatron.checkpointing import save_checkpoint
         from megatron.global_vars import set_global_variables, get_args
-        from megatron.model import ModelType
+        from megatron.core.enums import ModelType
         from megatron.tokenizer.tokenizer import _vocab_size_with_padding
         from megatron import fused_kernels
         from megatron.core import mpu
...