Commit 56934a2d authored by shanmugamr's avatar shanmugamr
Browse files

Adding some basic unit tests

parent 423623cb
No preview for this file type
image: gitlab-master.nvidia.com/dl/dgx/pytorch:21.12-py3-devel image: gitlab-master.nvidia.com/dl/dgx/pytorch:21.12-py3-devel
[runners.docker]
gpus = "all"
test: test:
tags: tags:
- docker - docker
......
...@@ -29,7 +29,7 @@ def test_initialize_model_parallel(): ...@@ -29,7 +29,7 @@ def test_initialize_model_parallel():
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size)) assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2)) assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=world_size))
ps.initialize_model_parallel() ps.initialize_model_parallel()
assert(ps.model_parallel_is_initialized()) assert(ps.model_parallel_is_initialized())
...@@ -112,12 +112,6 @@ def test_get_tensor_model_parallel_src_rank(): ...@@ -112,12 +112,6 @@ def test_get_tensor_model_parallel_src_rank():
assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size)) assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
ps.destroy_model_parallel() ps.destroy_model_parallel()
def test_global_memory_buffer():
ps._GLOBAL_MEMORY_BUFFER = None
ps._set_global_memory_buffer()
assert(ps.get_global_memory_buffer() is not None)
""" """
def test_get_virtual_pipeline_model_parallel_world_size(): def test_get_virtual_pipeline_model_parallel_world_size():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment