Commit b69e2195 authored by shanmugamr

Adding some basic unit tests

parent 6ab70f5c
image: gitlab-master.nvidia.com/dl/dgx/pytorch:21.12-py3-devel

test:
  tags:
    - docker
  script:
    - python -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
    - torchrun --nproc_per_node=2 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
  artifacts:
    paths:
      - coverage
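To reproduce this job outside CI, the two script lines run as-is from the repository root, assuming a machine with at least two GPUs and pytest-cov installed (both implied by the image and flags above). The torchrun invocation is what gives the distributed tests below their world size of 2 and the LOCAL_RANK environment variable they read:

python -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
torchrun --nproc_per_node=2 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/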
@@ -99,7 +99,7 @@ def initialize_model_parallel(
     num_data_parallel_groups: int = world_size // data_parallel_size

     if virtual_pipeline_model_parallel_size is not None:
-        if not pipeline_model_parallel_size_ > 2:
+        if not pipeline_model_parallel_size > 2:
             raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
                                "interleaved schedule")
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
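Read in isolation, the corrected guard looks as follows. This is a minimal standalone sketch of just the renamed check (the old line referenced the undefined name pipeline_model_parallel_size_), not the full initialize_model_parallel body; in the real code these are local names inside that function, not parameters of a helper:

# Standalone sketch of the guard fixed in the hunk above.
def _check_interleaved_schedule(pipeline_model_parallel_size,
                                virtual_pipeline_model_parallel_size):
    # The interleaved schedule assigns multiple model chunks per stage,
    # which the code only permits with more than two pipeline stages.
    if virtual_pipeline_model_parallel_size is not None:
        if not pipeline_model_parallel_size > 2:
            raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
                               "interleaved schedule")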
import torch
import megatron.core.tensor_parallel.utils as util

def test_split_tensor_along_last_dim():
    input_tensor = torch.rand((3, 4))
    # split_tensor_along_last_dim(input_tensor, 2) returns two (3, 2) views:
    # columns 0:2 and 2:4 of the input.
    assert torch.equal(input_tensor[:, 0:2], util.split_tensor_along_last_dim(input_tensor, 2)[0])
    assert torch.equal(input_tensor[:, 2:], util.split_tensor_along_last_dim(input_tensor, 2)[1])
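For context, the utility under test splits a tensor's last dimension into equal chunks. A sketch of its behavior, with the signature recalled from megatron.core.tensor_parallel.utils and possibly differing in detail:

import torch

# Sketch of the utility under test; essentially a thin wrapper around torch.split.
def split_tensor_along_last_dim_sketch(tensor, num_partitions, contiguous_split_chunks=False):
    last_dim = tensor.dim() - 1
    assert tensor.size()[last_dim] % num_partitions == 0, "last dim must divide evenly"
    last_dim_size = tensor.size()[last_dim] // num_partitions
    chunks = torch.split(tensor, last_dim_size, dim=last_dim)
    # torch.split returns views; callers can request contiguous copies.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in chunks)
    return chunks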
@@ -4,16 +4,12 @@ import megatron.core.parallel_state as ps
 from datetime import timedelta
 import pytest

-#TODO: Maybe get these values frome environment variables
-rank = torch.cuda.current_device()
-world_size = 1 #torch.cuda.device_count()
 tensor_model_parallel_size = 1
 pipeline_model_parallel_size = 1
 virtual_pipeline_model_parallel_size = None
 pipeline_model_parallel_split_rank = None
+world_size = torch.cuda.device_count()
+rank = int(os.environ['LOCAL_RANK'])
+print('Rank is: ' + str(rank))

 def initialize_distributed():
-    rank = torch.cuda.current_device()
     print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')
     torch.cuda.set_device(rank % torch.cuda.device_count())
     init_method = 'tcp://'
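The hunk cuts off mid-function, and the new module-level code also needs import os for os.environ['LOCAL_RANK'] (presumably present near the top of the file, outside this hunk). A plausible completion of the initializer, assuming the usual MASTER_ADDR/MASTER_PORT environment variables that torchrun provides:

import os
import torch

# Hedged sketch of how such an initializer typically completes; the real
# body is truncated in the hunk above.
def initialize_distributed_sketch(rank, world_size):
    torch.cuda.set_device(rank % torch.cuda.device_count())
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method = 'tcp://' + master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend='nccl',
        world_size=world_size,
        rank=rank,
        init_method=init_method,
    )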
@@ -27,12 +23,15 @@ def test_initialize_model_parallel():
         assert(ps.initialize_model_parallel())
     initialize_distributed()
     with pytest.raises(RuntimeError):
-        assert(ps.initialize_model_parallel(tensor_model_parallel_size=2))
+        assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))
+    with pytest.raises(RuntimeError):
+        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))
+    with pytest.raises(RuntimeError):
+        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))
     with pytest.raises(RuntimeError):
-        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2))
+        assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))
     ps.initialize_model_parallel()

 def test_other_initializations():
     assert(ps.model_parallel_is_initialized())
     assert(ps.get_model_parallel_group() is not None)
     assert(ps.get_tensor_model_parallel_group() is not None)
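These pytest.raises expectations all trace back to the same constraint: the product of the model-parallel sizes has to divide the world size. A minimal sketch of that check, reconstructed from the behavior the test expects rather than copied from parallel_state (the function name is mine); the virtual-pipeline case additionally trips the size > 2 guard shown in the first hunk:

# Hedged sketch: why the calls above raise with 2 GPUs.
def check_world_size(world_size, tensor_model_parallel_size=1,
                     pipeline_model_parallel_size=1):
    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )

# check_world_size(2, tensor_model_parallel_size=4)    -> RuntimeError (4 does not divide 2)
# check_world_size(2, 2, 2)                            -> RuntimeError (4 does not divide 2)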
@@ -40,49 +39,94 @@ def test_other_initializations():
     assert(ps.get_data_parallel_group() is not None)
     assert(ps.get_embedding_group() is not None)
     assert(ps.get_position_embedding_group() is not None)
+    #TODO: Should change some of these tests below to actually test code
+    ps.destroy_model_parallel()

 def test_pipeline_parallel_initializations():
+    ps.initialize_model_parallel(pipeline_model_parallel_size=2)
     assert(ps.get_pipeline_model_parallel_first_rank() == 0)
-    assert(ps.get_data_parallel_src_rank() == 0)
-    assert(ps.get_pipeline_model_parallel_next_rank() == 0)
-    assert(ps.get_pipeline_model_parallel_prev_rank() == 0)
-    assert(ps.get_data_parallel_world_size() == world_size)
+    assert(ps.get_data_parallel_src_rank() == rank)
+    assert(ps.get_pipeline_model_parallel_next_rank() == (0 if rank == world_size - 1 else rank + 1))
+    assert(ps.get_pipeline_model_parallel_prev_rank() == (rank - 1 if rank > 0 else 1))
+    assert(ps.get_data_parallel_world_size() == world_size-1)
     assert(ps.get_data_parallel_rank() == 0)
+    ps.destroy_model_parallel()
+
+def test_data_parallel_initializations():
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    assert(ps.get_data_parallel_src_rank() == rank)
+    assert(ps.get_data_parallel_world_size() == world_size-1)
+    assert(ps.get_data_parallel_rank() == 0)
+    ps.destroy_model_parallel()

 def test_tensor_model_parallel_world_size():
-    ps.set_tensor_model_parallel_world_size(world_size)
+    ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_world_size() == world_size)
     ps.set_tensor_model_parallel_world_size(None)
     assert(ps.get_tensor_model_parallel_world_size() == world_size)
+    ps.destroy_model_parallel()

 def test_pipeline_model_parallel_world_size():
-    ps.set_pipeline_model_parallel_world_size(world_size)
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_pipeline_model_parallel_world_size() == world_size)
     ps.set_pipeline_model_parallel_world_size(None)
     assert(ps.get_pipeline_model_parallel_world_size() == world_size)
+    ps.destroy_model_parallel()

 def test_tensor_model_parallel_rank():
-    ps.set_tensor_model_parallel_rank(rank)
+    ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_rank() == rank)
     ps.set_tensor_model_parallel_rank(None)
     assert(ps.get_tensor_model_parallel_rank() == rank)
+    ps.destroy_model_parallel()

-def test_tensor_model_parallel_rank():
-    ps.set_pipeline_model_parallel_rank(rank)
+def test_pipeline_model_parallel_rank():
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_pipeline_model_parallel_rank() == rank)
     ps.set_pipeline_model_parallel_rank(None)
     assert(ps.get_pipeline_model_parallel_rank() == rank)
+    ps.destroy_model_parallel()

 def test_is_pipeline_first_stage():
-    assert(ps.is_pipeline_first_stage(ignore_virtual=True))
-    assert(ps.is_pipeline_first_stage())
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    assert(ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0))
+    assert(ps.is_pipeline_first_stage() == (rank == 0))
+    ps.destroy_model_parallel()

 def test_is_pipeline_last_stage():
-    assert(
-        ps.is_pipeline_last_stage(ignore_virtual=True) == (ps.get_pipeline_model_parallel_rank() == world_size-1)
-    )
-    assert(
-        ps.is_pipeline_last_stage() == (ps.get_pipeline_model_parallel_rank() == world_size-1)
-    )
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size-1))
+    assert(ps.is_pipeline_last_stage() == (rank == world_size-1))
+    ps.destroy_model_parallel()

+def test_virtual_pipeline_model_parallel_rank():
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    ps.set_virtual_pipeline_model_parallel_rank(rank)
+    assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
+    ps.destroy_model_parallel()
+
+def test_get_tensor_model_parallel_src_rank():
+    ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
+    assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
+    ps.destroy_model_parallel()
+
+def test_global_memory_buffer():
+    ps._GLOBAL_MEMORY_BUFFER = None
+    ps._set_global_memory_buffer()
+    assert(ps.get_global_memory_buffer() is not None)
"""
def test_get_virtual_pipeline_model_parallel_world_size():
ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
ps.set_virtual_pipeline_model_parallel_rank(world_size)
assert(ps.get_virtual_pipeline_model_parallel_world_size() == world_size)
ps.destroy_model_parallel()
def test_is_rank_in_embedding_group():
assert(ps.is_rank_in_embedding_group(ignore_virtual=True) == (rank in ps._EMBEDDING_GLOBAL_RANKS))
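The world_size-1 and rank-0 expectations in the pipeline and data-parallel tests above only come out right for the 2-process torchrun job defined in the CI config: with world_size=2 and a pipeline size of 2, every rank sits alone in its data-parallel group, so the data-parallel world size is 1 == world_size-1. A hypothetical helper illustrating that layout, assuming the rank grid described in parallel_state's documentation (the function name is mine, not the library's):

# Hypothetical helper: enumerate data-parallel groups for a given grid.
# Assumes the grouping scheme documented in megatron.core.parallel_state;
# not the actual implementation.
def data_parallel_group_ranks(world_size, tensor_mp_size, pipeline_mp_size):
    groups = []
    num_pipeline_model_parallel_groups = world_size // pipeline_mp_size
    for i in range(pipeline_mp_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_mp_size):
            groups.append(list(range(start_rank + j, end_rank, tensor_mp_size)))
    return groups

print(data_parallel_group_ranks(2, 1, 2))  # [[0], [1]] -> data-parallel world size 1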
@@ -114,20 +158,7 @@ def test_is_pipeline_stage_at_split():
         (ps.is_pipeline_stage_before_split(rank) and ps.is_pipeline_stage_after_split(rank+1))
     )

-def test_virtual_pipeline_model_parallel_rank():
-    ps.set_virtual_pipeline_model_parallel_rank(rank)
-    assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
-
-def test_virtual_pipeline_model_parallel_rank():
-    ps.set_virtual_pipeline_model_parallel_rank(rank)
-    assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
-
-def test_get_virtual_pipeline_model_parallel_world_size():
-    assert(ps.get_virtual_pipeline_model_parallel_world_size() == virtual_pipeline_model_parallel_size)
-
-def test_get_tensor_model_parallel_src_rank():
-    assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
-
-def global_memory_buffer():
-    ps._set_global_memory_buffer()
-    assert(ps.get_global_memory_buffer() is not None)
\ No newline at end of file
+def test_destroy_model_parallel():
+    ps.destroy_model_parallel()
+    assert(ps._MODEL_PARALLEL_GROUP is None)
+"""
\ No newline at end of file
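test_global_memory_buffer above only checks that the buffer gets constructed. A hedged extension that would exercise reuse as well; get_tensor's signature is recalled from megatron.core.utils.GlobalMemoryBuffer and may differ:

import torch
import megatron.core.parallel_state as ps

# Hedged sketch: two requests under the same name should be served from the
# same underlying storage (allocated on the current CUDA device).
def check_global_memory_buffer_reuse():
    ps._GLOBAL_MEMORY_BUFFER = None
    ps._set_global_memory_buffer()
    buf = ps.get_global_memory_buffer()
    a = buf.get_tensor((2, 2), torch.float32, "probe")
    b = buf.get_tensor((2, 2), torch.float32, "probe")
    assert a.data_ptr() == b.data_ptr()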