Commit 42c40715 authored by Jared Casper

Merge branch 'core' into 'main'

Move most of mpu functionality into a new "Megatron core"

See merge request ADLR/megatron-lm!462
parents fdc801ed d3a416cd
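
The practical effect for downstream code is an import-path change: helpers that used to live under megatron.mpu (e.g. mpu.initialize.set_tensor_model_parallel_rank) now come from megatron.core, either via the mpu re-export or directly from megatron.core.parallel_state. A minimal before/after sketch of the pattern applied throughout the checkpoint tools below (illustrative only):

# Before this merge request: parallel state lived under megatron.mpu.
# from megatron import mpu
# mpu.initialize.set_tensor_model_parallel_rank(0)

# After: the same helpers are provided by Megatron core.
from megatron.core import mpu                # re-export used by the checkpoint tools below
import megatron.core.parallel_state as ps    # direct module import used by the new tests

mpu.set_tensor_model_parallel_rank(0)        # equivalent calls
ps.set_tensor_model_parallel_rank(0)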
import torch
import megatron.core.parallel_state as ps
import pytest
from tests.test_utilities import Utils
import os

rank = Utils.rank
world_size = Utils.world_size


def test_initialize__and_destroy_model_parallel():
    with pytest.raises(AssertionError):
        assert(ps.initialize_model_parallel())
    Utils.initialize_distributed()
    with pytest.raises(RuntimeError):
        assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))
    with pytest.raises(RuntimeError):
        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))
    with pytest.raises(RuntimeError):
        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))
    with pytest.raises(RuntimeError):
        assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    assert(ps.model_parallel_is_initialized())
    assert(ps.get_model_parallel_group() is not None)
    assert(ps.get_tensor_model_parallel_group() is not None)
    assert(ps.get_pipeline_model_parallel_group() is not None)
    assert(ps.get_data_parallel_group() is not None)
    Utils.destroy_model_parallel()
    assert(ps._MODEL_PARALLEL_GROUP is None)


def test_pipeline_parallel_initializations():
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    assert(ps.get_pipeline_model_parallel_first_rank() == rank % 2)
    assert(ps.get_data_parallel_src_rank() == rank)
    assert(ps.get_pipeline_model_parallel_next_rank() == ((rank + 2) % world_size))
    assert(ps.get_pipeline_model_parallel_prev_rank() == ((rank - 2) % world_size))
    Utils.destroy_model_parallel()


def test_data_parallel_initializations():
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.get_data_parallel_src_rank() == rank)
    assert(ps.get_data_parallel_world_size() == 1)
    assert(ps.get_data_parallel_rank() == 0)
    Utils.destroy_model_parallel()


def test_tensor_model_parellel_world_size():
    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
    assert(ps.get_tensor_model_parallel_world_size() == world_size)
    ps.set_tensor_model_parallel_world_size(None)
    assert(ps.get_tensor_model_parallel_world_size() == world_size)
    Utils.destroy_model_parallel()


def test_pipeline_model_parallel_world_size():
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.get_pipeline_model_parallel_world_size() == world_size)
    ps.set_pipeline_model_parallel_world_size(None)
    assert(ps.get_pipeline_model_parallel_world_size() == world_size)
    Utils.destroy_model_parallel()


def test_tensor_model_parallel_rank():
    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
    assert(ps.get_tensor_model_parallel_rank() == rank)
    ps.set_tensor_model_parallel_rank(None)
    assert(ps.get_tensor_model_parallel_rank() == rank)
    Utils.destroy_model_parallel()


def test_pipeline_model_parallel_rank():
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.get_pipeline_model_parallel_rank() == rank)
    ps.set_pipeline_model_parallel_rank(None)
    assert(ps.get_pipeline_model_parallel_rank() == rank)
    Utils.destroy_model_parallel()


def test_is_pipeline_first_stage():
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0))
    assert(ps.is_pipeline_first_stage() == (rank == 0))
    Utils.destroy_model_parallel()


def test_is_pipeline_last_stage():
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size - 1))
    assert(ps.is_pipeline_last_stage() == (rank == world_size - 1))
    Utils.destroy_model_parallel()


def test_virtual_pipeline_model_parallel_rank():
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    ps.set_virtual_pipeline_model_parallel_rank(rank)
    assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
    Utils.destroy_model_parallel()


def test_get_tensor_model_parallel_src_rank():
    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
    assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
    Utils.destroy_model_parallel()
\ No newline at end of file
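
The arithmetic in test_pipeline_parallel_initializations follows from the group layout parallel_state builds for tensor_model_parallel_size=2 and pipeline_model_parallel_size=4: tensor-parallel ranks are adjacent, so pipeline stages within a pipeline group are strided by the tensor-parallel size. A small stand-alone sketch of the expected values, mirroring the assertions above (assumes 8 GPUs and data-parallel size 1):

# Illustrative only: expected pipeline peers for the tp=2, pp=4 case used in the test.
world_size, tp = 8, 2
for rank in range(world_size):
    first_rank = rank % tp                    # get_pipeline_model_parallel_first_rank()
    next_rank = (rank + tp) % world_size      # get_pipeline_model_parallel_next_rank()
    prev_rank = (rank - tp) % world_size      # get_pipeline_model_parallel_prev_rank()
    print(rank, first_rank, next_rank, prev_rank)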
import os
import torch
import megatron.core.parallel_state as ps


class Utils:

    world_size = torch.cuda.device_count()
    rank = int(os.environ['LOCAL_RANK'])

    @staticmethod
    def initialize_distributed():
        print(f'Initializing torch.distributed with rank: {Utils.rank}, world_size: {Utils.world_size}')
        torch.cuda.set_device(Utils.rank % torch.cuda.device_count())
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
        torch.distributed.init_process_group(backend='nccl', world_size=Utils.world_size, rank=Utils.rank, init_method=init_method)

    @staticmethod
    def destroy_model_parallel():
        ps.destroy_model_parallel()
        torch.distributed.barrier()

    @staticmethod
    def initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None, pipeline_model_parallel_split_rank=None):
        ps.destroy_model_parallel()
        if not torch.distributed.is_initialized():
            Utils.initialize_distributed()
        ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank)
\ No newline at end of file
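
Typical use of this helper from a test, matching the files above (sketch only): set up the desired parallel layout, exercise megatron.core.parallel_state, then tear everything down so the next test can choose a different layout. Because Utils reads LOCAL_RANK and sizes the world to torch.cuda.device_count(), the tests are expected to run under a multi-process launcher that sets those variables (for example torchrun --nproc_per_node=<num_gpus> -m pytest tests; the launch command is an assumption, not something specified in this merge request).

# Sketch of the usage pattern in the tests above.
import megatron.core.parallel_state as ps
from tests.test_utilities import Utils

Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
assert ps.model_parallel_is_initialized()
Utils.destroy_model_parallel()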
import pytest
import torch
import megatron.core.utils as util
import numpy as np


def test_divide_properly():
    assert util.divide(4, 2) == 2


def test_divide_improperly():
    with pytest.raises(AssertionError):
        util.divide(4, 5)


def test_global_memory_buffer():
    global_memory_buffer = util.GlobalMemoryBuffer()
    obtained_tensor = global_memory_buffer.get_tensor((3, 2), torch.float32, "test_tensor")
    expected_tensor = torch.empty((3, 2), dtype=torch.float32, device=torch.cuda.current_device())
    assert torch.equal(obtained_tensor, expected_tensor)


def test_make_viewless_tensor():
    inp = torch.rand((3, 4))
    assert(torch.equal(inp, util.make_viewless_tensor(inp, True, True)))
    assert(torch.equal(inp, util.make_viewless_tensor(inp, True, False)))


def test_safely_set_viewless_tensor_data():
    tensor = torch.zeros((3, 4))
    new_data_tensor = torch.tensor(np.random.rand(3, 4))
    util.safely_set_viewless_tensor_data(tensor, new_data_tensor)
    assert(torch.equal(tensor, new_data_tensor))


def test_assert_viewless_tensor():
    tensor = torch.rand((3, 4))
    assert(torch.equal(util.assert_viewless_tensor(tensor), tensor))
    input_tensor_list = [tensor, tensor, tensor]
    output_tensor_list = util.assert_viewless_tensor(input_tensor_list)
    for inp, out in zip(input_tensor_list, output_tensor_list):
        assert(torch.equal(inp, out))
@@ -30,7 +30,8 @@ def _load_checkpoint(queue, args):
         from megatron.global_vars import set_args, set_global_variables
         from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
         from megatron.model import ModelType, module
-        from megatron import mpu, fused_kernels
+        from megatron.core import mpu
+        from megatron import fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         queue.put("exit")
@@ -99,7 +100,7 @@ def _load_checkpoint(queue, args):
         nonlocal consumed_valid_samples
         models = []
         for rank in range(count):
-            mpu.initialize.set_tensor_model_parallel_rank(rank)
+            mpu.parallel_state.set_tensor_model_parallel_rank(rank)
             model_ = [model_provider(pre_process, post_process).to(dtype)]
             margs.consumed_train_samples = 0
             margs.consumed_valid_samples = 0
@@ -123,8 +124,8 @@ def _load_checkpoint(queue, args):
         exit(1)
     set_global_variables(margs)
-    mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
-    mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
+    mpu.parallel_state.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
+    mpu.parallel_state.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
     fused_kernels.load(margs)
     # Get true (non-padded) vocab size
@@ -162,7 +163,7 @@ def _load_checkpoint(queue, args):
     md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
     # Get first pipe stage
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.parallel_state.set_pipeline_model_parallel_rank(0)
     post_process = pp_size == 1
     models = get_models(tp_size, md.params_dtype, True, post_process)
@@ -188,7 +189,7 @@ def _load_checkpoint(queue, args):
     total_layer_num = 0
     for pp_rank in range(pp_size):
         if pp_rank > 0:
-            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
+            mpu.parallel_state.set_pipeline_model_parallel_rank(pp_rank)
             post_process = pp_rank == pp_size - 1
             models = get_models(tp_size, md.params_dtype, False, post_process)
         for layer_num in range(len(models[0].language_model.encoder.layers)):
@@ -34,7 +34,8 @@ def save_checkpoint(queue, args):
         from megatron.global_vars import set_global_variables, get_args
         from megatron.model import ModelType
         from megatron.tokenizer.tokenizer import _vocab_size_with_padding
-        from megatron import mpu, fused_kernels
+        from megatron import fused_kernels
+        from megatron.core import mpu
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         exit(1)
@@ -152,10 +153,10 @@ def save_checkpoint(queue, args):
         return models
     # fake initializing distributed
-    mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
-    mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
-    mpu.initialize.set_tensor_model_parallel_rank(0)
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
+    mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
+    mpu.set_tensor_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     fused_kernels.load(margs)
     # Embeddings
@@ -197,7 +198,7 @@ def save_checkpoint(queue, args):
     out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
     # Make models for first pipeline stage and fill in embeddings
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     post_process = args.target_pipeline_parallel_size == 1
     models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
     for tp_rank, model in enumerate(models):
@@ -211,7 +212,7 @@ def save_checkpoint(queue, args):
     for pp_rank in range(args.target_pipeline_parallel_size):
         # For later pipeline parallel ranks, make the new models
         if pp_rank > 0:
-            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
+            mpu.set_pipeline_model_parallel_rank(pp_rank)
             post_process = pp_rank == args.target_pipeline_parallel_size - 1
             models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
@@ -317,6 +318,6 @@ def save_checkpoint(queue, args):
             print("ERROR: got some more data but was expecting to be done")
     for tp_rank in range(args.target_tensor_parallel_size):
-        mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
+        mpu.set_tensor_model_parallel_rank(tp_rank)
         save_checkpoint(md.iteration, [models[tp_rank]], None, None)
     print("Done!")
@@ -8,7 +8,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
 import socket
 from megatron import get_args
 from megatron import print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel