Commit 6ab70f5c authored by shanmugamr

Adding some basic unit tests

parent fb8c09eb
import os

import torch


def main():
    # Derive rank and world size from the local GPU configuration.
    rank = torch.cuda.current_device()
    world_size = torch.cuda.device_count()
    print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Build a tcp:// rendezvous address from the environment (with localhost
    # defaults) and initialize the NCCL process group.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend='nccl', world_size=world_size, rank=rank, init_method=init_method)


if __name__ == '__main__':
    main()
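

# Illustrative sketch, not part of the original commit: the same rendezvous
# can also be expressed with PyTorch's built-in 'env://' init method, which
# reads MASTER_ADDR and MASTER_PORT from the environment itself (e.g. when
# launched via torchrun).
def main_env_init():
    rank = torch.cuda.current_device()
    world_size = torch.cuda.device_count()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    # Equivalent to the manually assembled 'tcp://host:port' string above.
    torch.distributed.init_process_group(
        backend='nccl', world_size=world_size, rank=rank, init_method='env://')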


import os
from datetime import timedelta

import pytest
import torch

import megatron.core.parallel_state as ps

# TODO: Maybe get these values from environment variables.
rank = torch.cuda.current_device()
world_size = 1  # torch.cuda.device_count()
tensor_model_parallel_size = 1
pipeline_model_parallel_size = 1
virtual_pipeline_model_parallel_size = None
pipeline_model_parallel_split_rank = None


def initialize_distributed():
    rank = torch.cuda.current_device()
    print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')
    torch.cuda.set_device(rank % torch.cuda.device_count())
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend='nccl', world_size=world_size, rank=rank,
        init_method=init_method, timeout=timedelta(seconds=10))


def test_initialize_model_parallel():
    # Before torch.distributed is initialized, model-parallel setup must fail.
    # (No inner assert here: wrapping the call in assert(...) would raise
    # AssertionError even on success, since the function returns None.)
    with pytest.raises(AssertionError):
        ps.initialize_model_parallel()
    initialize_distributed()
    # Requesting more model parallelism than there are ranks must fail.
    with pytest.raises(RuntimeError):
        ps.initialize_model_parallel(tensor_model_parallel_size=2)
    with pytest.raises(RuntimeError):
        ps.initialize_model_parallel(pipeline_model_parallel_size=2)
    ps.initialize_model_parallel()
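

# Illustrative sketch, not part of the original commit: the tests below
# deliberately share the model-parallel state set up above. If per-test
# isolation were wanted instead, megatron.core.parallel_state's
# destroy_model_parallel() could reset that state in a teardown fixture:
@pytest.fixture
def model_parallel_teardown():
    yield
    ps.destroy_model_parallel()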


def test_other_initializations():
    assert(ps.model_parallel_is_initialized())
    assert(ps.get_model_parallel_group() is not None)
    assert(ps.get_tensor_model_parallel_group() is not None)
    assert(ps.get_pipeline_model_parallel_group() is not None)
    assert(ps.get_data_parallel_group() is not None)
    assert(ps.get_embedding_group() is not None)
    assert(ps.get_position_embedding_group() is not None)
    # TODO: Should change some of these tests below to actually test the code.
    assert(ps.get_pipeline_model_parallel_first_rank() == 0)
    assert(ps.get_data_parallel_src_rank() == 0)
    assert(ps.get_pipeline_model_parallel_next_rank() == 0)
    assert(ps.get_pipeline_model_parallel_prev_rank() == 0)
    assert(ps.get_data_parallel_world_size() == world_size)
    assert(ps.get_data_parallel_rank() == 0)


def test_tensor_model_parallel_world_size():
    ps.set_tensor_model_parallel_world_size(world_size)
    assert(ps.get_tensor_model_parallel_world_size() == world_size)
    # Resetting the override to None falls back to the actual group size.
    ps.set_tensor_model_parallel_world_size(None)
    assert(ps.get_tensor_model_parallel_world_size() == world_size)


def test_pipeline_model_parallel_world_size():
    ps.set_pipeline_model_parallel_world_size(world_size)
    assert(ps.get_pipeline_model_parallel_world_size() == world_size)
    ps.set_pipeline_model_parallel_world_size(None)
    assert(ps.get_pipeline_model_parallel_world_size() == world_size)


def test_tensor_model_parallel_rank():
    ps.set_tensor_model_parallel_rank(rank)
    assert(ps.get_tensor_model_parallel_rank() == rank)
    # Resetting the override to None falls back to the actual group rank.
    ps.set_tensor_model_parallel_rank(None)
    assert(ps.get_tensor_model_parallel_rank() == rank)


def test_pipeline_model_parallel_rank():
    ps.set_pipeline_model_parallel_rank(rank)
    assert(ps.get_pipeline_model_parallel_rank() == rank)
    ps.set_pipeline_model_parallel_rank(None)
    assert(ps.get_pipeline_model_parallel_rank() == rank)


def test_is_pipeline_first_stage():
    assert(ps.is_pipeline_first_stage(ignore_virtual=True))
    assert(ps.is_pipeline_first_stage())


def test_is_pipeline_last_stage():
    assert(
        ps.is_pipeline_last_stage(ignore_virtual=True) == (ps.get_pipeline_model_parallel_rank() == world_size - 1)
    )
    assert(
        ps.is_pipeline_last_stage() == (ps.get_pipeline_model_parallel_rank() == world_size - 1)
    )


def test_is_rank_in_embedding_group():
    assert(ps.is_rank_in_embedding_group(ignore_virtual=True) == (rank in ps._EMBEDDING_GLOBAL_RANKS))
    # The first/last-stage branches only apply to ranks that are actually in
    # the embedding group.
    if rank in ps._EMBEDDING_GLOBAL_RANKS:
        if rank == ps._EMBEDDING_GLOBAL_RANKS[0]:
            assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_first_stage())
        elif rank == ps._EMBEDDING_GLOBAL_RANKS[-1]:
            assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_last_stage())
        else:
            assert(ps.is_rank_in_embedding_group())


def test_is_rank_in_position_embedding_group():
    assert(ps.is_rank_in_position_embedding_group() == (rank in ps._POSITION_EMBEDDING_GLOBAL_RANKS))


def test_is_pipeline_stage_before_split():
    if world_size == 1:
        assert(ps.is_pipeline_stage_before_split())
    # TODO: Extend this for world sizes greater than one.
    assert(ps.is_pipeline_stage_before_split())


def test_is_pipeline_stage_after_split():
    if world_size == 1:
        assert(ps.is_pipeline_stage_after_split())
    # TODO: Extend this for world sizes greater than one.
    assert(ps.is_pipeline_stage_after_split())


def test_is_pipeline_stage_at_split():
    assert(
        ps.is_pipeline_stage_at_split() ==
        (ps.is_pipeline_stage_before_split(rank) and ps.is_pipeline_stage_after_split(rank + 1))
    )


def test_virtual_pipeline_model_parallel_rank():
    ps.set_virtual_pipeline_model_parallel_rank(rank)
    assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)


def test_get_virtual_pipeline_model_parallel_world_size():
    assert(ps.get_virtual_pipeline_model_parallel_world_size() == virtual_pipeline_model_parallel_size)


def test_get_tensor_model_parallel_src_rank():
    assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))


# NOTE: not prefixed with test_, so pytest does not collect it;
# initialize_model_parallel() above already creates the global buffer, and
# _set_global_memory_buffer does not allow re-initialization.
def global_memory_buffer():
    ps._set_global_memory_buffer()
    assert(ps.get_global_memory_buffer() is not None)


import numpy as np
import pytest
import torch

import megatron.core.utils as util


def test_divide_properly():
    assert util.divide(4, 2) == 2


def test_divide_improperly():
    with pytest.raises(AssertionError):
        util.divide(4, 5)


def test_global_memory_buffer():
    global_memory_buffer = util.GlobalMemoryBuffer()
    obtained_tensor = global_memory_buffer.get_tensor((3, 2), torch.float32, "test_tensor")
    expected_tensor = torch.empty((3, 2), dtype=torch.float32, device=torch.cuda.current_device())
    # torch.empty returns uninitialized memory, so comparing values would be
    # flaky; compare metadata instead.
    assert obtained_tensor.shape == expected_tensor.shape
    assert obtained_tensor.dtype == expected_tensor.dtype
    assert obtained_tensor.device == expected_tensor.device
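

# Illustrative sketch, not part of the original commit: this assumes that
# GlobalMemoryBuffer caches allocations by (name, dtype), so a repeated
# request should hand back the same underlying storage rather than
# re-allocating.
def test_global_memory_buffer_reuse():
    buffer = util.GlobalMemoryBuffer()
    first = buffer.get_tensor((3, 2), torch.float32, "reuse_check")
    second = buffer.get_tensor((3, 2), torch.float32, "reuse_check")
    # Identical data pointers mean the buffer was reused, not re-allocated.
    assert first.data_ptr() == second.data_ptr()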


def test_make_viewless_tensor():
    inp = torch.rand((3, 4))
    assert(torch.equal(inp, util.make_viewless_tensor(inp, True, True)))
    assert(torch.equal(inp, util.make_viewless_tensor(inp, True, False)))


def test_safely_set_viewless_tensor_data():
    tensor = torch.zeros((3, 4))
    new_data_tensor = torch.tensor(np.random.rand(3, 4))
    util.safely_set_viewless_tensor_data(tensor, new_data_tensor)
    assert(torch.equal(tensor, new_data_tensor))


def test_assert_viewless_tensor():
    tensor = torch.rand((3, 4))
    assert(torch.equal(util.assert_viewless_tensor(tensor), tensor))
    # Lists of tensors are handled element-wise.
    input_tensor_list = [tensor, tensor, tensor]
    output_tensor_list = util.assert_viewless_tensor(input_tensor_list)
    for inp, out in zip(input_tensor_list, output_tensor_list):
        assert(torch.equal(inp, out))
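

# Illustrative sketch, not part of the original commit: "viewless" means the
# tensor's ._base is None, i.e. it does not alias another tensor's storage.
# make_viewless_tensor is expected to produce such a tensor even when its
# input is a genuine view.
def test_make_viewless_tensor_strips_base():
    view = torch.rand((3, 4)).view(4, 3)
    assert view._base is not None  # a real view keeps a reference to its parent
    out = util.make_viewless_tensor(view, True, False)
    assert out._base is None
    assert torch.equal(out, view)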