Commit 2fd9ea1a authored by shanmugamr

Code coverage

parent 423623cb
File deleted
@@ -5,7 +5,8 @@ test:
     - docker
   script:
     - nvidia-smi
     - torchrun --nproc_per_node=2 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
+  coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
   artifacts:
     paths:
       - coverage
...
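The new coverage: key is GitLab's log-scraping regex: it pulls the percentage out of the TOTAL row that pytest-cov's --cov-report=term output ends with. A minimal sketch of what the pattern matches, assuming a typical pytest-cov terminal report (the sample log line is hypothetical; GitLab's /.../ delimiters are stripped, and \% simply matches a literal %):

import re

# The coverage pattern from .gitlab-ci.yml above, without the /.../ wrapper.
PATTERN = re.compile(r'(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$')

# Hypothetical final line of a pytest-cov --cov-report=term run.
sample = 'TOTAL                            1024    187    82%'

match = PATTERN.search(sample)
print(match.group(1) if match else 'no match')  # prints: 82%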
import os

import torch


def main():
    rank = torch.cuda.current_device()
    world_size = torch.cuda.device_count()
    print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')
    torch.cuda.set_device(rank % torch.cuda.device_count())
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(backend='nccl', world_size=world_size, rank=rank, init_method=init_method)


if __name__ == '__main__':
    main()
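This script takes its rank from torch.cuda.current_device() and its world size from torch.cuda.device_count(), so it assumes a single node on which each process has already been bound to its own GPU. For comparison, a sketch of the same initialization driven by the environment variables torchrun exports for every worker (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, MASTER_PORT); the helper name is hypothetical, not part of this commit:

import os

import torch


def init_distributed_from_env():
    # torchrun sets RANK, WORLD_SIZE and LOCAL_RANK per worker.
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
    # init_method='env://' reads MASTER_ADDR and MASTER_PORT itself.
    torch.distributed.init_process_group(backend='nccl', world_size=world_size, rank=rank, init_method='env://')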
@@ -18,6 +18,20 @@ def initialize_distributed():
     init_method += master_ip + ':' + master_port
     torch.distributed.init_process_group(backend='nccl', world_size=world_size, rank=rank, init_method=init_method, timeout=timedelta(seconds=10))
 
+def initialize_model_parallel(
+    tensor_model_parallel_size: int = 1,
+    pipeline_model_parallel_size: int = 1,
+    virtual_pipeline_model_parallel_size=None,
+    pipeline_model_parallel_split_rank=None,
+):
+    # This might not be the right way to do this.
+    try:
+        ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank)
+    except:
+        ps.destroy_model_parallel()
+        ps.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank)
+    pass
+
 def test_initialize_model_parallel():
     with pytest.raises(AssertionError):
         assert(ps.initialize_model_parallel())
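The initialize_model_parallel helper added above catches a failed initialization, tears the parallel state down, and retries, as its own comment concedes. A sketch of an equivalent that cleans up leftover state up front instead of catching the failure (hypothetical name, not part of this commit):

def fresh_initialize_model_parallel(**kwargs):
    # Destroy state left behind by a previous test, then initialize once.
    if ps.model_parallel_is_initialized():
        ps.destroy_model_parallel()
    ps.initialize_model_parallel(**kwargs)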
@@ -30,7 +44,7 @@ def test_initialize_model_parallel():
         assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))
     with pytest.raises(RuntimeError):
         assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))
-    ps.initialize_model_parallel()
+    initialize_model_parallel()
     assert(ps.model_parallel_is_initialized())
     assert(ps.get_model_parallel_group() is not None)
@@ -42,24 +56,22 @@ def test_initialize_model_parallel():
     ps.destroy_model_parallel()
 
 def test_pipeline_parallel_initializations():
-    ps.initialize_model_parallel(pipeline_model_parallel_size=2)
+    initialize_model_parallel(pipeline_model_parallel_size=2)
     assert(ps.get_pipeline_model_parallel_first_rank() == 0)
     assert(ps.get_data_parallel_src_rank() == rank)
     assert(ps.get_pipeline_model_parallel_next_rank() == 0 if rank == world_size - 1 else rank + 1)
-    assert(ps.get_pipeline_model_parallel_prev_rank() == rank - 1 if rank > 0 else 1)
-    assert(ps.get_data_parallel_world_size() == world_size-1)
-    assert(ps.get_data_parallel_rank() == 0)
+    assert(ps.get_pipeline_model_parallel_prev_rank() == rank - 1 if rank > 0 else world_size - 1)
     ps.destroy_model_parallel()
 
 def test_data_parallel_initializations():
-    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_data_parallel_src_rank() == rank)
     assert(ps.get_data_parallel_world_size() == world_size-1)
     assert(ps.get_data_parallel_rank() == 0)
     ps.destroy_model_parallel()
 
 def test_tensor_model_parellel_world_size():
-    ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_world_size() == world_size)
     ps.set_tensor_model_parallel_world_size(None)
     assert(ps.get_tensor_model_parallel_world_size() == world_size)
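The prev-rank change in the hunk above fixes the wraparound case: in a pipeline ring, rank 0's predecessor is the last rank, not rank 1. A worked sketch of the neighbor arithmetic those assertions encode (assuming one pipeline group spanning all ranks):

# With world_size = 2: rank 0 has next 1 and prev 1; rank 1 has next 0 and prev 0.
def next_pipeline_rank(rank, world_size):
    return (rank + 1) % world_size


def prev_pipeline_rank(rank, world_size):
    return (rank - 1) % world_size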
@@ -67,7 +79,7 @@ def test_tensor_model_parellel_world_size():
 def test_pipeline_model_parallel_world_size():
-    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_pipeline_model_parallel_world_size() == world_size)
     ps.set_pipeline_model_parallel_world_size(None)
     assert(ps.get_pipeline_model_parallel_world_size() == world_size)
@@ -75,40 +87,40 @@ def test_pipeline_model_parallel_world_size():
 def test_tensor_model_parallel_rank():
-    ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_rank() == rank)
     ps.set_tensor_model_parallel_rank(None)
     assert(ps.get_tensor_model_parallel_rank() == rank)
     ps.destroy_model_parallel()
 
 def test_pipeline_model_parallel_rank():
-    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.get_pipeline_model_parallel_rank() == rank)
     ps.set_pipeline_model_parallel_rank(None)
     assert(ps.get_pipeline_model_parallel_rank() == rank)
     ps.destroy_model_parallel()
 
 def test_is_pipeline_first_stage():
-    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0))
     assert(ps.is_pipeline_first_stage() == (rank == 0))
     ps.destroy_model_parallel()
 
 def test_is_pipeline_last_stage():
-    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    initialize_model_parallel(pipeline_model_parallel_size=world_size)
     assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size-1))
     assert(ps.is_pipeline_last_stage() == (rank == world_size-1))
     ps.destroy_model_parallel()
 
 def test_virtual_pipeline_model_parallel_rank():
-    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    initialize_model_parallel(pipeline_model_parallel_size=world_size)
     ps.set_virtual_pipeline_model_parallel_rank(rank)
     assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
     ps.destroy_model_parallel()
 
 def test_get_tensor_model_parallel_src_rank():
-    ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
     assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
     ps.destroy_model_parallel()
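The src-rank expression in test_get_tensor_model_parallel_src_rank floors a rank down to the first rank of its tensor-parallel group. Because the test sets tensor_model_parallel_size equal to world_size, every rank maps to 0; the general shape of the arithmetic, sketched with a hypothetical group size of 2:

# (rank // group_size) * group_size floors a rank to its group's first rank.
# group_size = 2: ranks 0,1 -> 0; ranks 2,3 -> 2; and so on.
def tensor_model_parallel_src_rank(rank, group_size):
    return (rank // group_size) * group_size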
@@ -121,7 +133,7 @@ def test_global_memory_buffer():
     """
 
 def test_get_virtual_pipeline_model_parallel_world_size():
-    ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
+    initialize_model_parallel(pipeline_model_parallel_size=world_size)
     ps.set_virtual_pipeline_model_parallel_rank(world_size)
     assert(ps.get_virtual_pipeline_model_parallel_world_size() == world_size)
     ps.destroy_model_parallel()
...