Commit b69e2195 authored by shanmugamr

Adding some basic unit tests

parent 6ab70f5c
CI configuration:

 image: gitlab-master.nvidia.com/dl/dgx/pytorch:21.12-py3-devel

 test:
+  tags:
+    - docker
   script:
-    - python -m pytest --cov-report term --cov-report=html --cov=megatron/core tests/
+    - torchrun --nproc_per_node=2 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
   artifacts:
     paths:
       - coverage
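The script change is the heart of the CI update: instead of running pytest in a single process, the suite is launched through torchrun with two workers, so tests that create real torch.distributed process groups can exercise more than one rank. The same invocation should work locally on a machine with two GPUs, run from the repository root:

    torchrun --nproc_per_node=2 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/

The added docker tag presumably routes the job to GitLab runners able to start the pytorch:21.12 container.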
megatron/core/parallel_state.py:

@@ -99,7 +99,7 @@ def initialize_model_parallel(
     num_data_parallel_groups: int = world_size // data_parallel_size

     if virtual_pipeline_model_parallel_size is not None:
-        if not pipeline_model_parallel_size_ > 2:
+        if not pipeline_model_parallel_size > 2:
             raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
                                "interleaved schedule")
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
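The one-character fix matters: the old guard tested pipeline_model_parallel_size_, a stale trailing-underscore name rather than the function's actual argument, so the interleaved-schedule check never saw the real pipeline size. With the fix, requesting a virtual pipeline schedule on a shallow pipeline fails loudly, which is what the new test below exercises. A minimal sketch of the enforced behavior, assuming torch.distributed is already initialized across the launched ranks (the test name here is hypothetical, not from the commit):

    import pytest
    import megatron.core.parallel_state as ps

    def test_interleaved_schedule_needs_deep_pipeline():
        # An interleaved (virtual) pipeline schedule only makes sense with
        # more than 2 pipeline stages, so this must raise RuntimeError.
        with pytest.raises(RuntimeError):
            ps.initialize_model_parallel(pipeline_model_parallel_size=2,
                                         virtual_pipeline_model_parallel_size=2)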
New test module for megatron.core.tensor_parallel.utils (file added by this commit):

+import torch
+import megatron.core.tensor_parallel.utils as util
+
+def test_split_tensor_along_last_dim():
+    input_tensor = torch.rand((3, 4))
+    # Splitting a (3, 4) tensor into 2 partitions along the last dim yields
+    # two (3, 2) views; assert so a mismatch actually fails the test.
+    assert torch.equal(input_tensor[:, 0:2], util.split_tensor_along_last_dim(input_tensor, 2)[0])
+    assert torch.equal(input_tensor[:, 2:], util.split_tensor_along_last_dim(input_tensor, 2)[1])
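For intuition, split_tensor_along_last_dim behaves like a thin wrapper over torch.split on the final dimension. A rough sketch of the behavior under test (the real utility also accepts a flag to force contiguous chunks, which this sketch omits):

    import torch

    def split_along_last_dim_sketch(tensor, num_partitions):
        # Divide the last dimension evenly and return the resulting views.
        last_dim = tensor.dim() - 1
        chunk_size = tensor.size()[last_dim] // num_partitions
        return torch.split(tensor, chunk_size, dim=last_dim)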
Test module for megatron.core.parallel_state:

@@ -4,16 +4,12 @@ import megatron.core.parallel_state as ps
 from datetime import timedelta
 import pytest

-#TODO: Maybe get these values frome environment variables
-rank = torch.cuda.current_device()
-world_size = 1 #torch.cuda.device_count()
-tensor_model_parallel_size = 1
-pipeline_model_parallel_size = 1
-virtual_pipeline_model_parallel_size = None
-pipeline_model_parallel_split_rank = None
+world_size = torch.cuda.device_count()
+rank = int(os.environ['LOCAL_RANK'])
+print('Rank is: ' + str(rank))

 def initialize_distributed():
-    rank = torch.cuda.current_device()
     print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')
     torch.cuda.set_device(rank % torch.cuda.device_count())
     init_method = 'tcp://'
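The hunk cuts off after init_method = 'tcp://'. A minimal sketch of how such an initializer is typically completed; the environment-variable names and defaults below are assumptions, not the commit's exact code:

    import os
    import torch

    # world_size and rank come from the module-level values above.
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method = 'tcp://' + master_ip + ':' + master_port
    # Create the default process group spanning all torchrun workers.
    torch.distributed.init_process_group(backend='nccl', world_size=world_size,
                                         rank=rank, init_method=init_method)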
@@ -27,12 +23,15 @@ def test_initialize_model_parallel():
     assert(ps.initialize_model_parallel())
     initialize_distributed()
     with pytest.raises(RuntimeError):
-        assert(ps.initialize_model_parallel(tensor_model_parallel_size=2))
+        assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))
+    with pytest.raises(RuntimeError):
+        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))
+    with pytest.raises(RuntimeError):
+        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))
     with pytest.raises(RuntimeError):
-        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2))
+        assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))
     ps.initialize_model_parallel()

-def test_other_initializations():
     assert(ps.model_parallel_is_initialized())
     assert(ps.get_model_parallel_group() is not None)
     assert(ps.get_tensor_model_parallel_group() is not None)
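With two launched workers the previously hard-coded failure cases no longer fail, so the test now derives them from world_size: a tensor or pipeline parallel size of 2*world_size = 4 oversubscribes the 2 available ranks, as does pipeline world_size x tensor world_size = 4, and requesting virtual_pipeline_model_parallel_size=2 with the default pipeline size of 1 trips the interleaved-schedule guard fixed above. The asserts formerly in test_other_initializations then run against a clean default initialization.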
@@ -40,49 +39,94 @@ def test_other_initializations():
     assert(ps.get_data_parallel_group() is not None)
     assert(ps.get_embedding_group() is not None)
     assert(ps.get_position_embedding_group() is not None)
-    #TODO : Should change some of these test below to actually test code
+    ps.destroy_model_parallel()
+
+def test_pipeline_parallel_initializations():
+    ps.initialize_model_parallel(pipeline_model_parallel_size=2)
     assert(ps.get_pipeline_model_parallel_first_rank() == 0)
-    assert(ps.get_data_parallel_src_rank() == 0)
-    assert(ps.get_pipeline_model_parallel_next_rank() == 0)
-    assert(ps.get_pipeline_model_parallel_prev_rank() == 0)
-    assert(ps.get_data_parallel_world_size() == world_size)
+    assert(ps.get_data_parallel_src_rank() == rank)
+    assert(ps.get_pipeline_model_parallel_next_rank() == 0 if rank == world_size - 1 else rank + 1)
+    assert(ps.get_pipeline_model_parallel_prev_rank() == rank - 1 if rank > 0 else 1)
+    assert(ps.get_data_parallel_world_size() == world_size-1)
     assert(ps.get_data_parallel_rank() == 0)
+    ps.destroy_model_parallel()
+
+def test_data_parallel_initializations():
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    assert(ps.get_data_parallel_src_rank() == rank)
+    assert(ps.get_data_parallel_world_size() == world_size-1)
+    assert(ps.get_data_parallel_rank() == 0)
+    ps.destroy_model_parallel()

 def test_tensor_model_parellel_world_size():
-    ps.set_tensor_model_parallel_world_size(world_size)
+    ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_world_size() == world_size)
     ps.set_tensor_model_parallel_world_size(None)
     assert(ps.get_tensor_model_parallel_world_size() == world_size)
+    ps.destroy_model_parallel()

 def test_pipeline_model_parallel_world_size():
-    ps.set_pipeline_model_parallel_world_size(world_size)
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_pipeline_model_parallel_world_size() == world_size)
     ps.set_pipeline_model_parallel_world_size(None)
     assert(ps.get_pipeline_model_parallel_world_size() == world_size)
+    ps.destroy_model_parallel()

 def test_tensor_model_parallel_rank():
-    ps.set_tensor_model_parallel_rank(rank)
+    ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_rank() == rank)
     ps.set_tensor_model_parallel_rank(None)
     assert(ps.get_tensor_model_parallel_rank() == rank)
+    ps.destroy_model_parallel()

-def test_tensor_model_parallel_rank():
-    ps.set_pipeline_model_parallel_rank(rank)
+def test_pipeline_model_parallel_rank():
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_pipeline_model_parallel_rank() == rank)
     ps.set_pipeline_model_parallel_rank(None)
     assert(ps.get_pipeline_model_parallel_rank() == rank)
+    ps.destroy_model_parallel()

 def test_is_pipeline_first_stage():
-    assert(ps.is_pipeline_first_stage(ignore_virtual=True))
-    assert(ps.is_pipeline_first_stage())
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    assert(ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0))
+    assert(ps.is_pipeline_first_stage() == (rank == 0))
+    ps.destroy_model_parallel()

 def test_is_pipeline_last_stage():
-    assert(
-        ps.is_pipeline_last_stage(ignore_virtual=True) == (ps.get_pipeline_model_parallel_rank() == world_size-1)
-    )
-    assert(
-        ps.is_pipeline_last_stage() == (ps.get_pipeline_model_parallel_rank() == world_size-1)
-    )
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size-1))
+    assert(ps.is_pipeline_last_stage() == (rank == world_size-1))
+    ps.destroy_model_parallel()

+def test_virtual_pipeline_model_parallel_rank():
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    ps.set_virtual_pipeline_model_parallel_rank(rank)
+    assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
+    ps.destroy_model_parallel()
+
+def test_get_tensor_model_parallel_src_rank():
+    ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
+    assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
+    ps.destroy_model_parallel()
+
+def test_global_memory_buffer():
+    ps._GLOBAL_MEMORY_BUFFER = None
+    ps._set_global_memory_buffer()
+    assert(ps.get_global_memory_buffer() is not None)
+
+"""
+def test_get_virtual_pipeline_model_parallel_world_size():
+    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    ps.set_virtual_pipeline_model_parallel_rank(world_size)
+    assert(ps.get_virtual_pipeline_model_parallel_world_size() == world_size)
+    ps.destroy_model_parallel()
+"""

 def test_is_rank_in_embedding_group():
     assert(ps.is_rank_in_embedding_group(ignore_virtual=True) == (rank in ps._EMBEDDING_GLOBAL_RANKS))
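Each new test wraps its assertions in an initialize_model_parallel/destroy_model_parallel pair so state from one parallel configuration cannot leak into the next test. A hypothetical alternative (not the commit's code) is a pytest fixture, which also guarantees teardown when an assertion fails mid-test:

    import os
    import pytest
    import torch
    import megatron.core.parallel_state as ps

    world_size = torch.cuda.device_count()

    @pytest.fixture
    def pipeline_parallel_state():
        # Set up, hand control to the test, then always tear down.
        ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
        yield
        ps.destroy_model_parallel()

    def test_first_stage(pipeline_parallel_state):
        rank = int(os.environ['LOCAL_RANK'])
        assert ps.is_pipeline_first_stage() == (rank == 0)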
@@ -114,20 +158,7 @@ def test_is_pipeline_stage_at_split():
         (ps.is_pipeline_stage_before_split(rank) and ps.is_pipeline_stage_after_split(rank+1))
     )

-def test_virtual_pipeline_model_parallel_rank():
-    ps.set_virtual_pipeline_model_parallel_rank(rank)
-    assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
-
-def test_get_virtual_pipeline_model_parallel_world_size():
-    assert(ps.get_virtual_pipeline_model_parallel_world_size() == virtual_pipeline_model_parallel_size)
-
-def test_get_tensor_model_parallel_src_rank():
-    assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
-
-def global_memory_buffer():
-    ps._set_global_memory_buffer()
-    assert(ps.get_global_memory_buffer() is not None)
+def test_destroy_model_parallel():
+    ps.destroy_model_parallel()
+    assert(ps._MODEL_PARALLEL_GROUP is None)
\ No newline at end of file