Commit 42c40715 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'core' into 'main'

Move most of mpu functionality into a new "Megatron core"

See merge request ADLR/megatron-lm!462
parents fdc801ed d3a416cd
from setuptools import setup, find_packages
setup(
name="megatron.core",
version="0.1",
description="Core components of Megatron.",
packages=find_packages(
include=("megatron.core")
)
)
......@@ -10,7 +10,7 @@ import torch
from megatron import get_args
from megatron import print_rank_last, is_last_rank
from megatron import mpu
from megatron.core import mpu
from megatron.schedules import get_forward_backward_func
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch
......
......@@ -9,7 +9,7 @@ import torch
from megatron import get_args, get_num_microbatches
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.core import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import ModelType
......
......@@ -5,7 +5,6 @@
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
......
......@@ -6,10 +6,10 @@ import json
import torch
import requests
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.core import mpu
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
......
......@@ -10,7 +10,7 @@ import torch.nn.functional as F
from torch.utils.data import DataLoader
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.core import mpu
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import build_data_loader
......
......@@ -9,8 +9,8 @@ import math
import torch
import torch.nn.functional as F
from megatron import get_args, get_timers, get_tokenizer
from megatron import mpu, print_rank_0
from megatron import get_args, get_timers, get_tokenizer, print_rank_0
from megatron.core import mpu
from megatron.indexer import IndexBuilder
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.utils import average_losses_across_data_parallel_group
......
......@@ -13,7 +13,7 @@ import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, BatchSampler
from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron import print_rank_0, get_args, get_tokenizer
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_nq_dataset(qa_data, split):
......
......@@ -5,7 +5,6 @@
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.multiple_choice import MultipleChoice
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
......
......@@ -9,7 +9,7 @@ import torch
from megatron import get_args
from megatron import print_rank_0, print_rank_last
from megatron import mpu
from megatron.core import mpu
from megatron.schedules import get_forward_backward_func
from tasks.vision.finetune_utils import build_data_loader
from tasks.vision.finetune_utils import process_batch
......
......@@ -7,7 +7,8 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu, utils
from megatron import utils
from megatron.core import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
......
......@@ -7,7 +7,8 @@ import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from megatron import print_rank_0, print_rank_last
from megatron.core import mpu
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
......
......@@ -6,7 +6,8 @@ import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from megatron import print_rank_0, print_rank_last
from megatron.core import mpu
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
......
......@@ -9,7 +9,7 @@ import torch
from megatron import get_args
from megatron import print_rank_0, is_last_rank
from megatron import get_tokenizer
from megatron import mpu
from megatron.core import mpu
from megatron.checkpointing import load_checkpoint
from megatron.model import GPTModel
from megatron.training import get_model
......@@ -93,7 +93,7 @@ def forward_step(batch, model, eval_metric):
if mpu.is_pipeline_last_stage():
# For loss, return the unreduced loss.
if eval_metric == 'loss':
losses = mpu.vocab_parallel_cross_entropy(
losses = mpu.tensor_parallel.vocab_parallel_cross_entropy(
output.contiguous().float(), labels.contiguous())
loss = torch.sum(
losses.view(-1) * loss_mask.contiguous().view(-1).float())
......
from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
import torch
from tests.test_utilities import Utils
import numpy as np
def test_vocab_parallel_cross_entropy():
Utils.initialize_model_parallel(4,2)
vocab_parallel_logits = torch.range(0,7).repeat(16,4).cuda()
target = torch.arange(0,32,2).cuda()
output = vocab_parallel_cross_entropy(vocab_parallel_logits, target)
expected_output = torch.tensor([10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309,
10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309]).cuda()
assert(torch.equal(torch.round(expected_output), torch.round(output)))
Utils.destroy_model_parallel()
\ No newline at end of file
from megatron.core.tensor_parallel.data import broadcast_data
import torch
from tests.test_utilities import Utils
def test_broadcast_data():
Utils.initialize_model_parallel(2,4)
input_data = {
0 : torch.ones((8,8)).cuda() * 0.0,
1 : torch.ones((8,8)).cuda() * 1.0,
2 : torch.ones((8,8)).cuda() * 2.0,
3 : torch.ones((8,8)).cuda() * 3.0,
4 : torch.ones((8,8)).cuda() * 4.0,
5 : torch.ones((8,8)).cuda() * 5.0,
6 : torch.ones((8,8)).cuda() * 6.0,
7 : torch.ones((8,8)).cuda() * 7.0
}
dtype = torch.float32
actual_output = broadcast_data([0,1],input_data, dtype)
assert(torch.equal(actual_output[0], input_data[0]))
assert(torch.equal(actual_output[1], input_data[1]))
Utils.destroy_model_parallel()
\ No newline at end of file
from megatron.core.tensor_parallel import mappings
from tests.test_utilities import Utils
import torch
def test_CopyToModelParallelRegion():
Utils.initialize_model_parallel(4,2)
input_data = torch.ones((1)).cuda()*Utils.rank
output_data = mappings._CopyToModelParallelRegion.backward(None, input_data)
result = torch.ones(1).cuda()
result = result * 22 if Utils.rank >= 4 else result * 6
assert(torch.equal(output_data, result))
assert(torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data)))
assert(torch.equal(input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data)))
Utils.destroy_model_parallel()
def test_ReduceFromModelParallelRegion():
Utils.initialize_model_parallel(4,2)
input_data = torch.ones((1)).cuda()*Utils.rank
output_data = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data)
result = torch.ones(1).cuda()
result = result * 22 if Utils.rank >= 4 else result * 6
assert(torch.equal(output_data, result))
input_data = torch.ones((1)).cuda()*Utils.rank
assert(torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result))
assert(torch.equal(input_data, mappings._ReduceFromModelParallelRegion.backward(None, input_data)))
Utils.destroy_model_parallel()
def test_ScatterToModelParallelRegion():
Utils.initialize_model_parallel(4,2)
input_data = torch.rand((8,4)).cuda()
output_data = mappings.scatter_to_tensor_model_parallel_region(input_data)
req_dim = int(Utils.rank%(Utils.world_size/2))
assert(torch.equal(output_data, input_data[:,req_dim].reshape((8,1))))
output_data = mappings._ScatterToModelParallelRegion.symbolic(None, input_data)
assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1))))
input_data = torch.ones(8).cuda() * Utils.rank
actual_output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
expected_output = torch.cat((
torch.ones(8)*0,
torch.ones(8)*1,
torch.ones(8)*2,
torch.ones(8)*3)).cuda()
if (Utils.rank >= 4):
expected_output = expected_output + 4
assert(torch.equal(actual_output_data, expected_output))
Utils.destroy_model_parallel()
def test_GatherFromModelParallelRegion():
Utils.initialize_model_parallel(4,2)
input_data = torch.rand((8,4)).cuda()
req_dim = int(Utils.rank%(Utils.world_size/2))
output_data = mappings._GatherFromModelParallelRegion.backward(None, input_data)
assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1))))
input_data = torch.ones(8).cuda() * Utils.rank
actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data)
expected_output = torch.cat((
torch.ones(8)*0,
torch.ones(8)*1,
torch.ones(8)*2,
torch.ones(8)*3)).cuda()
if (Utils.rank >= 4):
expected_output = expected_output + 4
assert(torch.equal(actual_output_data, expected_output))
assert(torch.equal(mappings._GatherFromModelParallelRegion.symbolic(None, input_data), expected_output))
Utils.destroy_model_parallel()
def test_ScatterToSequenceParallelRegion():
Utils.initialize_model_parallel(4,2)
input_data = torch.rand((8,4)).cuda()
req_dim = int(Utils.rank%(Utils.world_size/2))*2
output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data)
assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
output_data = mappings.scatter_to_sequence_parallel_region(input_data)
assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
input_data = torch.ones(4).cuda() * Utils.rank
output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
expected_output = torch.concat((
torch.ones(4)*0,
torch.ones(4)*1,
torch.ones(4)*2,
torch.ones(4)*3)).cuda()
if (Utils.rank >= 4):
expected_output = expected_output + 4
assert(torch.equal(output_data, expected_output))
Utils.destroy_model_parallel()
def test_GatherFromSequenceParallelRegion():
Utils.initialize_model_parallel(4,2)
input_data = torch.ones(4).cuda() * Utils.rank
output_data = mappings.gather_from_sequence_parallel_region(input_data)
expected_output = torch.concat((
torch.ones(4)*0,
torch.ones(4)*1,
torch.ones(4)*2,
torch.ones(4)*3)).cuda()
if (Utils.rank >= 4):
expected_output = expected_output + 4
assert(torch.equal(output_data, expected_output))
assert(torch.equal(mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data), expected_output))
input_data = torch.vstack((
torch.ones(4)*0,
torch.ones(4)*1,
torch.ones(4)*2,
torch.ones(4)*3)).cuda()
class Ctx:
tensor_parallel_output_grad = True
output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data)
expected_output = torch.ones((1,4)).cuda() * 4 * int(Utils.rank % 4)
assert(torch.equal(output_data[0], expected_output))
Utils.destroy_model_parallel()
def test_ReduceScatterToSequenceParallelRegion():
Utils.initialize_model_parallel(4,2)
input_data = torch.vstack((
torch.ones(4)*0,
torch.ones(4)*1,
torch.ones(4)*2,
torch.ones(4)*3)).cuda()
output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data)
expected_output = torch.ones(4).cuda() * 4 * int(Utils.rank % 4)
assert(torch.equal(output_data[0], expected_output))
assert(torch.equal(mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data) , expected_output.reshape((1,4))))
input_data = torch.ones(4).cuda() * Utils.rank
output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(None,input_data)
expected_output = torch.concat((
torch.ones(4)*0,
torch.ones(4)*1,
torch.ones(4)*2,
torch.ones(4)*3)).cuda()
if (Utils.rank >= 4):
expected_output = expected_output + 4
assert(torch.equal(output_data, expected_output))
Utils.destroy_model_parallel()
from megatron.core.tensor_parallel.random import CudaRNGStatesTracker
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.tensor_parallel.random import _CUDA_RNG_STATE_TRACKER
from megatron.core.tensor_parallel.random import checkpoint
from tests.test_utilities import Utils
import pytest
import torch
def test_cuda_rng_states_tracker():
rng_tracker = CudaRNGStatesTracker()
rng_tracker.set_states({"state1":1234})
assert(rng_tracker.get_states()["state1"] == 1234)
rng_tracker.reset()
assert(rng_tracker.get_states() == {})
seed = 1111
rng_tracker.add("state2",seed)
with pytest.raises(Exception):
assert(rng_tracker.add("state3",seed))
with pytest.raises(Exception):
assert(rng_tracker.add("state2",111))
assert(rng_tracker.get_states()['state2'] is not None)
with pytest.raises(Exception):
assert()
rng_tracker.fork("state2")
torch.cuda.manual_seed(seed)
rng_state = torch.cuda.get_rng_state()
assert torch.equal(rng_tracker.get_states()['state2'], rng_state)
def test_model_parallel_cuda_manual_seed():
Utils.initialize_model_parallel(4,2)
model_parallel_cuda_manual_seed(0)
assert(_CUDA_RNG_STATE_TRACKER.get_states()['model-parallel-rng'] is not None)
Utils.destroy_model_parallel()
def test_checkpoint():
def test_forward(*input):
return input[0]+input[1]
assert(torch.equal(torch.ones(16)*3,checkpoint(test_forward, None, torch.ones(16), torch.ones(16)*2)))
Utils.initialize_model_parallel()
input1 = torch.ones((4,4))
checkpoint(test_forward, True, input1, torch.ones((4,4))*2)
assert(torch.equal(torch.ones(input1.numel()).cuda(), input1))
Utils.destroy_model_parallel()
\ No newline at end of file
import torch
import megatron.core.tensor_parallel.utils as util
import megatron.core.parallel_state as ps
from tests.test_utilities import Utils
rank = Utils.rank
def test_split_tensor_along_last_dim():
input_tensor = torch.rand((3,4))
torch.equal(input_tensor[0:2,0:2], util.split_tensor_along_last_dim(input_tensor,2)[0])
torch.equal(input_tensor[2:,2:], util.split_tensor_along_last_dim(input_tensor,2)[1])
def test_split_tensor_into_1d_equal_chunks():
Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
input_tensor = torch.rand((3,4))
output_tensor = util.split_tensor_into_1d_equal_chunks(input_tensor)
if rank % 2 == 0 :
start = 0
end = int(input_tensor.numel()/2)
else :
start = int(input_tensor.numel()/2)
end = input_tensor.numel()
assert torch.equal(output_tensor, input_tensor.flatten()[start:end])
Utils.destroy_model_parallel()
def test_gather_split_1d_tensor():
Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
input_tensor = torch.ones((2,4)).cuda() * rank
actual_output_tensor = util.gather_split_1d_tensor(input_tensor)
if rank %2 == 0:
expected_output_tensor = torch.concat((input_tensor.flatten(), input_tensor.flatten() + 1))
else :
expected_output_tensor = torch.concat((input_tensor.flatten() - 1, input_tensor.flatten()))
assert(torch.equal(actual_output_tensor, expected_output_tensor))
Utils.destroy_model_parallel()
def test_vocab():
global_vocab_size = 1600
per_partition_vocab_size = 1600 / Utils.world_size
assert((rank * per_partition_vocab_size, (rank + 1)* per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_per_partition_vocab_size(global_vocab_size // Utils.world_size, rank, Utils.world_size)))
assert((rank * per_partition_vocab_size, (rank + 1)* per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_global_vocab_size(global_vocab_size, rank, Utils.world_size)))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment