Commit 8806ba73 authored by Jared Casper

Merge branch 'properTest' into 'core'

Adding proper test cases

See merge request ADLR/megatron-lm!460
parents f8614670 8b94a160
File deleted
[html]
directory = coverage
[run]
data_file = .coverage_$LOCAL_RANK
__pycache__
*.so
build
.coverage_*
*.egg-info
@@ -4,8 +4,7 @@ test:
   tags:
     - docker_gpu_enabled
   script:
-    - nvidia-smi
-    - torchrun --nproc_per_node=2 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
+    - torchrun --nproc_per_node=8 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
   coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
   artifacts:
     paths:
...
@@ -22,6 +22,8 @@ from .utils import (
     gather_split_1d_tensor,
 )
+from megatron.core.utils import safely_set_viewless_tensor_data
 # Default name for the model parallel rng tracker.
 _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
...
from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
import torch
from tests.test_utilities import Utils
import numpy as np

def test_vocab_parallel_cross_entropy():
    Utils.initialize_model_parallel(4,2)
    vocab_parallel_logits = torch.range(0,7).repeat(16,4).cuda()
    target = torch.arange(0,32,2).cuda()
    output = vocab_parallel_cross_entropy(vocab_parallel_logits, target)
    expected_output = torch.tensor([10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309,
        10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309]).cuda()
    assert(torch.equal(torch.round(expected_output), torch.round(output)))
    Utils.destroy_model_parallel()
\ No newline at end of file
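As a sanity check on the expected values: with tensor_model_parallel_size=4 each rank holds a (16, 32) shard of the logits, so the full vocabulary is 128 wide and every row is the pattern 0..7 tiled 16 times. The logit at target index t is t % 8, so the loss is log(16 * sum(exp(0..7))) - t % 8 ≈ 10.2309 - t % 8. A minimal single-process sketch that reproduces those reference values with plain PyTorch (illustrative only, not part of the commit):

import torch
import torch.nn.functional as F

# Reconstruct the full, unsharded logits: every row is 0..7 tiled across a 128-wide vocab.
full_logits = torch.arange(0, 8, dtype=torch.float32).repeat(16, 16)  # shape (16, 128)
target = torch.arange(0, 32, 2)                                       # targets 0, 2, ..., 30

# Per-sample cross entropy over the full vocabulary matches expected_output above.
reference = F.cross_entropy(full_logits, target, reduction='none')
print(reference[:4])  # tensor([10.2309,  8.2309,  6.2309,  4.2309])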
from megatron.core.tensor_parallel.data import broadcast_data
import torch
from tests.test_utilities import Utils

def test_broadcast_data():
    Utils.initialize_model_parallel(2,4)
    input_data = {
        0 : torch.ones((8,8)).cuda() * 0.0,
        1 : torch.ones((8,8)).cuda() * 1.0,
        2 : torch.ones((8,8)).cuda() * 2.0,
        3 : torch.ones((8,8)).cuda() * 3.0,
        4 : torch.ones((8,8)).cuda() * 4.0,
        5 : torch.ones((8,8)).cuda() * 5.0,
        6 : torch.ones((8,8)).cuda() * 6.0,
        7 : torch.ones((8,8)).cuda() * 7.0
    }
    dtype = torch.float32
    actual_output = broadcast_data([0,1], input_data, dtype)
    assert(torch.equal(actual_output[0], input_data[0]))
    assert(torch.equal(actual_output[1], input_data[1]))
    Utils.destroy_model_parallel()
\ No newline at end of file
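broadcast_data sends the entries selected by its key list from tensor-parallel rank 0 to the other ranks of the group; because every rank builds an identical input_data dict here, the broadcast result must equal the local copy. A rough sketch of the underlying pattern using plain torch.distributed (a hypothetical helper, assuming the process group is already initialized under torchrun; the real implementation also broadcasts the sizes and enforces the dtype):

import torch
import torch.distributed as dist

def broadcast_selected(keys, data, dtype, src=0):
    # Pack the selected tensors into one flat buffer so a single broadcast suffices.
    if dist.get_rank() == src:
        flat = torch.cat([data[k].to(dtype).flatten() for k in keys]).cuda()
    else:
        numel = sum(data[k].numel() for k in keys)  # shapes assumed known on every rank
        flat = torch.empty(numel, dtype=dtype, device='cuda')
    dist.broadcast(flat, src=src)
    # Unpack back into tensors with the original shapes.
    out, offset = {}, 0
    for k in keys:
        n = data[k].numel()
        out[k] = flat[offset:offset + n].view_as(data[k])
        offset += n
    return out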
from megatron.core.tensor_parallel import mappings
from tests.test_utilities import Utils
import torch

def test_CopyToModelParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.ones((1)).cuda()*Utils.rank
    output_data = mappings._CopyToModelParallelRegion.backward(None, input_data)
    result = torch.ones(1).cuda()
    result = result * 22 if Utils.rank >= 4 else result * 6
    assert(torch.equal(output_data, result))
    assert(torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data)))
    assert(torch.equal(input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data)))
    Utils.destroy_model_parallel()

def test_ReduceFromModelParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.ones((1)).cuda()*Utils.rank
    output_data = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data)
    result = torch.ones(1).cuda()
    result = result * 22 if Utils.rank >= 4 else result * 6
    assert(torch.equal(output_data, result))
    input_data = torch.ones((1)).cuda()*Utils.rank
    assert(torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result))
    assert(torch.equal(input_data, mappings._ReduceFromModelParallelRegion.backward(None, input_data)))
    Utils.destroy_model_parallel()
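The hard-coded 6 and 22 in the two tests above come from the tensor-parallel grouping: with initialize_model_parallel(4, 2) on 8 GPUs, ranks 0-3 form one tensor-parallel group and ranks 4-7 the other, and all-reducing the per-rank value (the rank id) sums within each group. A one-line check of that arithmetic:

# Tensor-parallel groups under initialize_model_parallel(4, 2) on 8 ranks.
assert sum(range(0, 4)) == 6    # expected all-reduce result on ranks 0-3
assert sum(range(4, 8)) == 22   # expected all-reduce result on ranks 4-7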
def test_ScatterToModelParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.rand((8,4)).cuda()
    output_data = mappings.scatter_to_tensor_model_parallel_region(input_data)
    req_dim = int(Utils.rank%(Utils.world_size/2))
    assert(torch.equal(output_data, input_data[:,req_dim].reshape((8,1))))
    output_data = mappings._ScatterToModelParallelRegion.symbolic(None, input_data)
    assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1))))
    input_data = torch.ones(8).cuda() * Utils.rank
    actual_output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
    expected_output = torch.cat((
        torch.ones(8)*0,
        torch.ones(8)*1,
        torch.ones(8)*2,
        torch.ones(8)*3)).cuda()
    if (Utils.rank >= 4):
        expected_output = expected_output + 4
    assert(torch.equal(actual_output_data, expected_output))
    Utils.destroy_model_parallel()

def test_GatherFromModelParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.rand((8,4)).cuda()
    req_dim = int(Utils.rank%(Utils.world_size/2))
    output_data = mappings._GatherFromModelParallelRegion.backward(None, input_data)
    assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1))))
    input_data = torch.ones(8).cuda() * Utils.rank
    actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data)
    expected_output = torch.cat((
        torch.ones(8)*0,
        torch.ones(8)*1,
        torch.ones(8)*2,
        torch.ones(8)*3)).cuda()
    if (Utils.rank >= 4):
        expected_output = expected_output + 4
    assert(torch.equal(actual_output_data, expected_output))
    assert(torch.equal(mappings._GatherFromModelParallelRegion.symbolic(None, input_data), expected_output))
    Utils.destroy_model_parallel()

def test_ScatterToSequenceParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.rand((8,4)).cuda()
    req_dim = int(Utils.rank%(Utils.world_size/2))*2
    output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data)
    assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
    output_data = mappings.scatter_to_sequence_parallel_region(input_data)
    assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
    input_data = torch.ones(4).cuda() * Utils.rank
    output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
    expected_output = torch.concat((
        torch.ones(4)*0,
        torch.ones(4)*1,
        torch.ones(4)*2,
        torch.ones(4)*3)).cuda()
    if (Utils.rank >= 4):
        expected_output = expected_output + 4
    assert(torch.equal(output_data, expected_output))
    Utils.destroy_model_parallel()

def test_GatherFromSequenceParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.ones(4).cuda() * Utils.rank
    output_data = mappings.gather_from_sequence_parallel_region(input_data)
    expected_output = torch.concat((
        torch.ones(4)*0,
        torch.ones(4)*1,
        torch.ones(4)*2,
        torch.ones(4)*3)).cuda()
    if (Utils.rank >= 4):
        expected_output = expected_output + 4
    assert(torch.equal(output_data, expected_output))
    assert(torch.equal(mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data), expected_output))
    input_data = torch.vstack((
        torch.ones(4)*0,
        torch.ones(4)*1,
        torch.ones(4)*2,
        torch.ones(4)*3)).cuda()
    class Ctx:
        tensor_parallel_output_grad = True
    output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data)
    expected_output = torch.ones((1,4)).cuda() * 4 * int(Utils.rank % 4)
    assert(torch.equal(output_data[0], expected_output))
    Utils.destroy_model_parallel()

def test_ReduceScatterToSequenceParallelRegion():
    Utils.initialize_model_parallel(4,2)
    input_data = torch.vstack((
        torch.ones(4)*0,
        torch.ones(4)*1,
        torch.ones(4)*2,
        torch.ones(4)*3)).cuda()
    output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data)
    expected_output = torch.ones(4).cuda() * 4 * int(Utils.rank % 4)
    assert(torch.equal(output_data[0], expected_output))
    assert(torch.equal(mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data), expected_output.reshape((1,4))))
    input_data = torch.ones(4).cuda() * Utils.rank
    output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(None, input_data)
    expected_output = torch.concat((
        torch.ones(4)*0,
        torch.ones(4)*1,
        torch.ones(4)*2,
        torch.ones(4)*3)).cuda()
    if (Utils.rank >= 4):
        expected_output = expected_output + 4
    assert(torch.equal(output_data, expected_output))
    Utils.destroy_model_parallel()
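For orientation on the shapes in the scatter/gather tests above: with a tensor-parallel group of size 4, scatter_to_tensor_model_parallel_region splits the last dimension so each rank keeps an (8, 1) slice of the (8, 4) input, and the gather is the inverse concatenation. A CPU-only sketch of that slicing (plain torch, no process group):

import torch

full = torch.rand(8, 4)
shards = torch.split(full, 1, dim=-1)                 # the 4 per-rank slices of the last dim
assert shards[0].shape == (8, 1)
assert torch.equal(torch.cat(shards, dim=-1), full)   # gather is the inverse of scatter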
from megatron.core.tensor_parallel.random import CudaRNGStatesTracker
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.tensor_parallel.random import _CUDA_RNG_STATE_TRACKER
from megatron.core.tensor_parallel.random import checkpoint
from tests.test_utilities import Utils
import pytest
import torch

def test_cuda_rng_states_tracker():
    rng_tracker = CudaRNGStatesTracker()
    rng_tracker.set_states({"state1":1234})
    assert(rng_tracker.get_states()["state1"] == 1234)
    rng_tracker.reset()
    assert(rng_tracker.get_states() == {})
    seed = 1111
    rng_tracker.add("state2",seed)
    with pytest.raises(Exception):
        assert(rng_tracker.add("state3",seed))
    with pytest.raises(Exception):
        assert(rng_tracker.add("state2",111))
    assert(rng_tracker.get_states()['state2'] is not None)
    with pytest.raises(Exception):
        assert()
    rng_tracker.fork("state2")
    torch.cuda.manual_seed(seed)
    rng_state = torch.cuda.get_rng_state()
    assert torch.equal(rng_tracker.get_states()['state2'], rng_state)

def test_model_parallel_cuda_manual_seed():
    Utils.initialize_model_parallel(4,2)
    model_parallel_cuda_manual_seed(0)
    assert(_CUDA_RNG_STATE_TRACKER.get_states()['model-parallel-rng'] is not None)
    Utils.destroy_model_parallel()

def test_checkpoint():
    def test_forward(*input):
        return input[0]+input[1]
    assert(torch.equal(torch.ones(16)*3, checkpoint(test_forward, None, torch.ones(16), torch.ones(16)*2)))
    Utils.initialize_model_parallel()
    input1 = torch.ones((4,4))
    checkpoint(test_forward, True, input1, torch.ones((4,4))*2)
    assert(torch.equal(torch.ones(input1.numel()).cuda(), input1))
    Utils.destroy_model_parallel()
\ No newline at end of file
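The tracker tested above keeps named CUDA RNG states so that, for example, dropout inside and outside the tensor-parallel region can draw from separate, reproducible streams. A minimal sketch of the fork idea using only public torch.cuda state APIs (illustrative; CudaRNGStatesTracker.fork is the real context manager and additionally guards against name clashes):

import torch
from contextlib import contextmanager

states = {}
torch.cuda.manual_seed(1111)
states['state2'] = torch.cuda.get_rng_state()

@contextmanager
def fork(name):
    # Swap in the named RNG state; on exit remember where that stream left off
    # before restoring the default stream's state.
    orig = torch.cuda.get_rng_state()
    torch.cuda.set_rng_state(states[name])
    try:
        yield
    finally:
        states[name] = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(orig)

with fork('state2'):
    torch.randn(4, device='cuda')  # drawn from the tracked stream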
import torch
import megatron.core.tensor_parallel.utils as util
import megatron.core.parallel_state as ps
from tests.test_utilities import Utils

rank = Utils.rank

def test_split_tensor_along_last_dim():
    input_tensor = torch.rand((3,4))
    torch.equal(input_tensor[0:2,0:2], util.split_tensor_along_last_dim(input_tensor,2)[0])
    torch.equal(input_tensor[2:,2:], util.split_tensor_along_last_dim(input_tensor,2)[1])

def test_split_tensor_into_1d_equal_chunks():
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    input_tensor = torch.rand((3,4))
    output_tensor = util.split_tensor_into_1d_equal_chunks(input_tensor)
    if rank % 2 == 0:
        start = 0
        end = int(input_tensor.numel()/2)
    else:
        start = int(input_tensor.numel()/2)
        end = input_tensor.numel()
    assert torch.equal(output_tensor, input_tensor.flatten()[start:end])
    Utils.destroy_model_parallel()

def test_gather_split_1d_tensor():
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    input_tensor = torch.ones((2,4)).cuda() * rank
    actual_output_tensor = util.gather_split_1d_tensor(input_tensor)
    if rank % 2 == 0:
        expected_output_tensor = torch.concat((input_tensor.flatten(), input_tensor.flatten() + 1))
    else:
        expected_output_tensor = torch.concat((input_tensor.flatten() - 1, input_tensor.flatten()))
    assert(torch.equal(actual_output_tensor, expected_output_tensor))
    Utils.destroy_model_parallel()

def test_vocab():
    global_vocab_size = 1600
    per_partition_vocab_size = 1600 / Utils.world_size
    assert((rank * per_partition_vocab_size, (rank + 1) * per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_per_partition_vocab_size(global_vocab_size // Utils.world_size, rank, Utils.world_size)))
    assert((rank * per_partition_vocab_size, (rank + 1) * per_partition_vocab_size) == (util.VocabUtility.vocab_range_from_global_vocab_size(global_vocab_size, rank, Utils.world_size)))
\ No newline at end of file
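The vocab test relies on the partition arithmetic being exact: 1600 vocabulary entries split over Utils.world_size ranks gives each rank the contiguous range [rank * per_partition, (rank + 1) * per_partition). A quick standalone check with illustrative numbers (world_size 8, as in the CI job):

global_vocab_size, world_size = 1600, 8
per_partition = global_vocab_size // world_size              # 200 entries per rank
for rank in range(world_size):
    first, last = rank * per_partition, (rank + 1) * per_partition
    assert last - first == per_partition
assert (3 * per_partition, 4 * per_partition) == (600, 800)  # e.g. rank 3 owns [600, 800)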
-import os
 import torch
 import megatron.core.parallel_state as ps
-from datetime import timedelta
 import pytest
+from tests.test_utilities import Utils
+import os
+
+rank = Utils.rank
+world_size = Utils.world_size

-world_size = torch.cuda.device_count()
-rank = int(os.environ['LOCAL_RANK'])
-print('Ranks is : ' + str(rank))
-
-def initialize_distributed():
-    print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')
-    torch.cuda.set_device(rank % torch.cuda.device_count())
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(backend='nccl', world_size=world_size, rank=rank, init_method=init_method, timeout=timedelta(seconds=10))
-
-def initialize_model_parallel(
-    tensor_model_parallel_size: int = 1,
-    pipeline_model_parallel_size: int = 1,
-    virtual_pipeline_model_parallel_size = None,
-    pipeline_model_parallel_split_rank = None,
-):
-    # This might not be the right way to do this.
-    try:
-        ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank)
-    except:
-        ps.destroy_model_parallel()
-        ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank)
-    pass
-
-def test_initialize_model_parallel():
+def test_initialize__and_destroy_model_parallel():
     with pytest.raises(AssertionError):
         assert(ps.initialize_model_parallel())
-    initialize_distributed()
+    Utils.initialize_distributed()
     with pytest.raises(RuntimeError):
         assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))
     with pytest.raises(RuntimeError):
@@ -44,124 +19,86 @@ def test_initialize_model_parallel():
         assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))
     with pytest.raises(RuntimeError):
         assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))
-    initialize_model_parallel()
+    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
     assert(ps.model_parallel_is_initialized())
     assert(ps.get_model_parallel_group() is not None)
     assert(ps.get_tensor_model_parallel_group() is not None)
     assert(ps.get_pipeline_model_parallel_group() is not None)
     assert(ps.get_data_parallel_group() is not None)
-    assert(ps.get_embedding_group() is not None)
-    assert(ps.get_position_embedding_group() is not None)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()
+    assert(ps._MODEL_PARALLEL_GROUP is None)

 def test_pipeline_parallel_initializations():
-    initialize_model_parallel(pipeline_model_parallel_size=2)
-    assert(ps.get_pipeline_model_parallel_first_rank() == 0)
+    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
+    assert(ps.get_pipeline_model_parallel_first_rank() == rank % 2)
     assert(ps.get_data_parallel_src_rank() == rank)
-    assert(ps.get_pipeline_model_parallel_next_rank() == 0 if rank == world_size - 1 else rank + 1)
-    assert(ps.get_pipeline_model_parallel_prev_rank() == rank - 1 if rank > 0 else world_size - 1)
-    ps.destroy_model_parallel()
+    assert(ps.get_pipeline_model_parallel_next_rank() == ((rank + 2) % world_size))
+    assert(ps.get_pipeline_model_parallel_prev_rank() == ((rank - 2) % world_size))
+    Utils.destroy_model_parallel()

 def test_data_parallel_initializations():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_data_parallel_src_rank() == rank)
-    assert(ps.get_data_parallel_world_size() == world_size-1)
+    assert(ps.get_data_parallel_world_size() == 1)
     assert(ps.get_data_parallel_rank() == 0)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()

 def test_tensor_model_parellel_world_size():
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_world_size() == world_size)
     ps.set_tensor_model_parallel_world_size(None)
     assert(ps.get_tensor_model_parallel_world_size() == world_size)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()

 def test_pipeline_model_parallel_world_size():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_pipeline_model_parallel_world_size() == world_size)
     ps.set_pipeline_model_parallel_world_size(None)
     assert(ps.get_pipeline_model_parallel_world_size() == world_size)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()

 def test_tensor_model_parallel_rank():
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
    assert(ps.get_tensor_model_parallel_rank() == rank)
     ps.set_tensor_model_parallel_rank(None)
     assert(ps.get_tensor_model_parallel_rank() == rank)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()

 def test_pipeline_model_parallel_rank():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_pipeline_model_parallel_rank() == rank)
     ps.set_pipeline_model_parallel_rank(None)
     assert(ps.get_pipeline_model_parallel_rank() == rank)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()

 def test_is_pipeline_first_stage():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0))
     assert(ps.is_pipeline_first_stage() == (rank == 0))
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()

 def test_is_pipeline_last_stage():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size-1))
     assert(ps.is_pipeline_last_stage() == (rank == world_size-1))
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()

 def test_virtual_pipeline_model_parallel_rank():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     ps.set_virtual_pipeline_model_parallel_rank(rank)
     assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()

 def test_get_tensor_model_parallel_src_rank():
-    initialize_model_parallel(tensor_model_parallel_size=world_size)
+    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
-    ps.destroy_model_parallel()
+    Utils.destroy_model_parallel()
\ No newline at end of file
-
-"""
-def test_get_virtual_pipeline_model_parallel_world_size():
-    initialize_model_parallel(pipeline_model_parallel_size=world_size)
-    ps.set_virtual_pipeline_model_parallel_rank(world_size)
-    assert(ps.get_virtual_pipeline_model_parallel_world_size() == world_size)
-    ps.destroy_model_parallel()
-
-def test_is_rank_in_embedding_group():
-    assert(ps.is_rank_in_embedding_group(ignore_virtual=True) == (rank in ps._EMBEDDING_GLOBAL_RANKS))
-    if rank in ps._EMBEDDING_GLOBAL_RANKS:
-        assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_first_stage())
-    elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
-        assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_last_stage())
-    else:
-        assert(ps.is_rank_in_embedding_group())
-
-def test_is_rank_in_position_embedding_group():
-    assert(ps.is_rank_in_position_embedding_group() == (rank in ps._POSITION_EMBEDDING_GLOBAL_RANKS))
-
-def test_is_pipeline_stage_before_split():
-    if world_size == 1:
-        assert(ps.is_pipeline_stage_before_split())
-    # TODO: Changes here for more than one world size
-    assert(ps.is_pipeline_stage_before_split())
-
-def test_is_pipeline_stage_after_split():
-    if world_size == 1:
-        assert(ps.is_pipeline_stage_after_split())
-    # TODO: Changes here for more than one world size
-    assert(ps.is_pipeline_stage_before_split())
-
-def test_is_pipeline_stage_at_split():
-    assert(
-        ps.is_pipeline_stage_at_split() ==
-        (ps.is_pipeline_stage_before_split(rank) and ps.is_pipeline_stage_after_split(rank+1))
-    )
-
-def test_destroy_model_parallel():
-    ps.destroy_model_parallel()
-    assert(ps._MODEL_PARALLEL_GROUP is None)
-"""
\ No newline at end of file
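The expected values in test_pipeline_parallel_initializations follow from how the eight ranks are grouped when tensor_model_parallel_size=2 and pipeline_model_parallel_size=4: consecutive ranks share a tensor-parallel group, so the two pipeline groups are [0, 2, 4, 6] and [1, 3, 5, 7]. A small standalone check of that arithmetic (the grouping is written out by hand to mirror the assertions above, not taken from parallel_state itself):

# Pipeline groups under tensor_model_parallel_size=2, pipeline_model_parallel_size=4 on 8 GPUs.
world_size, tp = 8, 2
pipeline_groups = [list(range(start, world_size, tp)) for start in range(tp)]
assert pipeline_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]
for rank in range(world_size):
    group = pipeline_groups[rank % tp]
    assert group[0] == rank % 2                          # get_pipeline_model_parallel_first_rank()
    nxt = group[(group.index(rank) + 1) % len(group)]
    assert nxt == (rank + 2) % world_size                # get_pipeline_model_parallel_next_rank()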
import os
import torch
import megatron.core.parallel_state as ps

class Utils:

    world_size = torch.cuda.device_count()
    rank = int(os.environ['LOCAL_RANK'])

    @staticmethod
    def initialize_distributed():
        print(f'Initializing torch.distributed with rank: {Utils.rank}, world_size: {Utils.world_size}')
        torch.cuda.set_device(Utils.rank % torch.cuda.device_count())
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
        torch.distributed.init_process_group(backend='nccl', world_size=Utils.world_size, rank=Utils.rank, init_method=init_method)

    @staticmethod
    def destroy_model_parallel():
        ps.destroy_model_parallel()
        torch.distributed.barrier()

    @staticmethod
    def initialize_model_parallel(tensor_model_parallel_size = 1, pipeline_model_parallel_size = 1, virtual_pipeline_model_parallel_size = None, pipeline_model_parallel_split_rank = None):
        ps.destroy_model_parallel()
        if not torch.distributed.is_initialized():
            Utils.initialize_distributed()
        ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank)
\ No newline at end of file
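Utils assumes it is launched by the torchrun command in the CI job above, which starts eight workers and exports the rendezvous environment to each of them. A small illustrative snippet of the environment the class depends on (the fallback values are for illustration only; the committed class reads LOCAL_RANK directly):

import os
import torch

# torchrun exports these per worker; the defaults below are illustrative fallbacks.
local_rank = int(os.getenv('LOCAL_RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', str(torch.cuda.device_count())))
master_addr = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
print(f'worker {local_rank}/{world_size}, rendezvous at {master_addr}:{master_port}')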