OpenDAS / Megatron-LM

Commit b69e2195, authored Oct 06, 2022 by shanmugamr
Parent: 6ab70f5c

Adding some basic unit tests

Showing 4 changed files with 89 additions and 49 deletions
.gitlab-ci.yml (+3 −1)
megatron/core/parallel_state.py (+1 −1)
tests/tensor_parallel/test_tensor_parallel_utils.py (+7 −0)
tests/test_parallel_state.py (+78 −47)
.gitlab-ci.yml (View file @ b69e2195)

    image: gitlab-master.nvidia.com/dl/dgx/pytorch:21.12-py3-devel

    test:
      tags:
        - docker
      script:
        - python -m pytest --cov-report term --cov-report=html --cov=megatron/core tests/
        - torchrun --nproc_per_node=2 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
      artifacts:
        paths:
          - coverage
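Side note (not part of the commit): pytest-cov writes its HTML report to htmlcov/ by default, while the artifacts block above collects a coverage path. If the two are meant to line up, pytest-cov accepts an explicit destination for the HTML report; a minimal sketch of such an invocation, assuming the report should land in coverage/:

    # Hypothetical variant: direct the HTML coverage report into coverage/ so the
    # artifacts rule above would pick it up.
    python -m pytest --cov-report=term --cov-report=html:coverage --cov=megatron/core tests/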
megatron/core/parallel_state.py (View file @ b69e2195)

    @@ -99,7 +99,7 @@ def initialize_model_parallel(
        num_data_parallel_groups: int = world_size // data_parallel_size

        if virtual_pipeline_model_parallel_size is not None:
    -       if not pipeline_model_parallel_size_ > 2:
    +       if not pipeline_model_parallel_size > 2:
                raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
                                   "interleaved schedule")
            global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
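This is the commit's one-line fix: the guard previously tested pipeline_model_parallel_size_ (apparently a leftover trailing-underscore name rather than the function's pipeline_model_parallel_size parameter). A minimal sketch of a test that would exercise the corrected guard (illustrative only, not part of the commit; it assumes torch.distributed has already been initialized, e.g. via the initialize_distributed() helper in tests/test_parallel_state.py):

    import pytest
    import megatron.core.parallel_state as ps

    def test_interleaved_schedule_requires_pipeline_size_gt_2():
        # With an interleaved (virtual pipeline) schedule requested, a pipeline
        # size of 2 or less should now be rejected by the corrected check.
        with pytest.raises(RuntimeError):
            ps.initialize_model_parallel(
                pipeline_model_parallel_size=2,
                virtual_pipeline_model_parallel_size=2,
            )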
tests/tensor_parallel/test_tensor_parallel_utils.py (new file, 0 → 100644, View file @ b69e2195)

    import torch
    import megatron.core.tensor_parallel.utils as util

    def test_split_tensor_along_last_dim():
        input_tensor = torch.rand((3, 4))
        torch.equal(input_tensor[0:2, 0:2], util.split_tensor_along_last_dim(input_tensor, 2)[0])
        torch.equal(input_tensor[2:, 2:], util.split_tensor_along_last_dim(input_tensor, 2)[1])
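Note that torch.equal returns a bool and its result is not asserted here, so the test passes regardless of the comparison. A sketch of an asserting variant (not part of the commit; it assumes split_tensor_along_last_dim(tensor, 2) returns the two halves of the last dimension in order):

    import torch
    import megatron.core.tensor_parallel.utils as util

    def test_split_tensor_along_last_dim_asserting():
        # Sketch only: assumes the two returned chunks are columns 0:2 and 2:4
        # of a (3, 4) input tensor.
        input_tensor = torch.rand((3, 4))
        first, second = util.split_tensor_along_last_dim(input_tensor, 2)
        assert torch.equal(input_tensor[:, 0:2], first)
        assert torch.equal(input_tensor[:, 2:4], second)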
tests/test_parallel_state.py (View file @ b69e2195)

    @@ -4,16 +4,12 @@ import megatron.core.parallel_state as ps
    from datetime import timedelta
    import pytest

    #TODO: Maybe get these values frome environment variables
    rank = torch.cuda.current_device()
    world_size = 1 #torch.cuda.device_count()
    tensor_model_parallel_size = 1
    pipeline_model_parallel_size = 1
    virtual_pipeline_model_parallel_size = None
    pipeline_model_parallel_split_rank = None
    world_size = torch.cuda.device_count()
    rank = int(os.environ['LOCAL_RANK'])
    print('Ranks is : ' + str(rank))

    def initialize_distributed():
        rank = torch.cuda.current_device()
        print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')
        torch.cuda.set_device(rank % torch.cuda.device_count())
        init_method = 'tcp://'
    ...
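The body of initialize_distributed() is truncated in this view right after init_method = 'tcp://'. For orientation only, a generic sketch of how a TCP-rendezvous init typically looks when launched under torchrun; the function name, the MASTER_ADDR/MASTER_PORT lookups, and the call shape are illustrative assumptions, not the code from this commit:

    import os
    import torch

    def initialize_distributed_sketch():
        # Illustrative only -- not the commit's initialize_distributed().
        # torchrun exports LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT.
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
        init_method = 'tcp://' + os.environ['MASTER_ADDR'] + ':' + os.environ['MASTER_PORT']
        torch.distributed.init_process_group(
            backend='nccl', world_size=world_size, rank=rank, init_method=init_method)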
    @@ -27,12 +23,15 @@ def test_initialize_model_parallel():
        assert(ps.initialize_model_parallel())
        initialize_distributed()
        with pytest.raises(RuntimeError):
            assert(ps.initialize_model_parallel(tensor_model_parallel_size=2))
            assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))
        with pytest.raises(RuntimeError):
            assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))
        with pytest.raises(RuntimeError):
            assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))
        with pytest.raises(RuntimeError):
            assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2))
            assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))
        ps.initialize_model_parallel()

    def test_other_initializations():
        assert(ps.model_parallel_is_initialized())
        assert(ps.get_model_parallel_group() is not None)
        assert(ps.get_tensor_model_parallel_group() is not None)
    ...
    @@ -40,49 +39,94 @@ def test_other_initializations():
        assert(ps.get_data_parallel_group() is not None)
        assert(ps.get_embedding_group() is not None)
        assert(ps.get_position_embedding_group() is not None)
        #TODO : Should change some of these test below to actually test code
        ps.destroy_model_parallel()

    def test_pipeline_parallel_initializations():
        ps.initialize_model_parallel(pipeline_model_parallel_size=2)
        assert(ps.get_pipeline_model_parallel_first_rank() == 0)
        assert(ps.get_data_parallel_src_rank() == 0)
        assert(ps.get_pipeline_model_parallel_next_rank() == 0)
        assert(ps.get_pipeline_model_parallel_prev_rank() == 0)
        assert(ps.get_data_parallel_world_size() == world_size)
        assert(ps.get_data_parallel_src_rank() == rank)
        assert(ps.get_pipeline_model_parallel_next_rank() == 0 if rank == world_size - 1 else rank + 1)
        assert(ps.get_pipeline_model_parallel_prev_rank() == rank - 1 if rank > 0 else 1)
        assert(ps.get_data_parallel_world_size() == world_size - 1)
        assert(ps.get_data_parallel_rank() == 0)
        ps.destroy_model_parallel()

    def test_data_parallel_initializations():
        ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
        assert(ps.get_data_parallel_src_rank() == rank)
        assert(ps.get_data_parallel_world_size() == world_size - 1)
        assert(ps.get_data_parallel_rank() == 0)
        ps.destroy_model_parallel()

    def test_tensor_model_parellel_world_size():
        ps.set_tensor_model_parallel_world_size(world_size)
        ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
        assert(ps.get_tensor_model_parallel_world_size() == world_size)
        ps.set_tensor_model_parallel_world_size(None)
        assert(ps.get_tensor_model_parallel_world_size() == world_size)
        ps.destroy_model_parallel()

    def test_pipeline_model_parallel_world_size():
        ps.set_pipeline_model_parallel_world_size(world_size)
        ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
        assert(ps.get_pipeline_model_parallel_world_size() == world_size)
        ps.set_pipeline_model_parallel_world_size(None)
        assert(ps.get_pipeline_model_parallel_world_size() == world_size)
        ps.destroy_model_parallel()

    def test_tensor_model_parallel_rank():
        ps.set_tensor_model_parallel_rank(rank)
        ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
        assert(ps.get_tensor_model_parallel_rank() == rank)
        ps.set_tensor_model_parallel_rank(None)
        assert(ps.get_tensor_model_parallel_rank() == rank)
        ps.destroy_model_parallel()

    def test_tensor_model_parallel_rank():
        ps.set_pipeline_model_parallel_rank(rank)
    def test_pipeline_model_parallel_rank():
        ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
        assert(ps.get_pipeline_model_parallel_rank() == rank)
        ps.set_pipeline_model_parallel_rank(None)
        assert(ps.get_pipeline_model_parallel_rank() == rank)
        ps.destroy_model_parallel()

    def test_is_pipeline_first_stage():
        assert(ps.is_pipeline_first_stage(ignore_virtual=True))
        assert(ps.is_pipeline_first_stage())
        ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
        assert(ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0))
        assert(ps.is_pipeline_first_stage() == (rank == 0))
        ps.destroy_model_parallel()

    def test_is_pipeline_last_stage():
        assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (ps.get_pipeline_model_parallel_rank() == world_size - 1))
        assert(ps.is_pipeline_last_stage() == (ps.get_pipeline_model_parallel_rank() == world_size - 1))
        ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
        assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size - 1))
        assert(ps.is_pipeline_last_stage() == (rank == world_size - 1))
        ps.destroy_model_parallel()

    def test_virtual_pipeline_model_parallel_rank():
        ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
        ps.set_virtual_pipeline_model_parallel_rank(rank)
        assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
        ps.destroy_model_parallel()

    def test_get_tensor_model_parallel_src_rank():
        ps.initialize_model_parallel(tensor_model_parallel_size=world_size)
        assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
        ps.destroy_model_parallel()

    def test_global_memory_buffer():
        ps._GLOBAL_MEMORY_BUFFER = None
        ps._set_global_memory_buffer()
        assert(ps.get_global_memory_buffer() is not None)

    """
    def test_get_virtual_pipeline_model_parallel_world_size():
        ps.initialize_model_parallel(pipeline_model_parallel_size=world_size)
        ps.set_virtual_pipeline_model_parallel_rank(world_size)
        assert(ps.get_virtual_pipeline_model_parallel_world_size() == world_size)
        ps.destroy_model_parallel()

    def test_is_rank_in_embedding_group():
        assert(ps.is_rank_in_embedding_group(ignore_virtual=True) == (rank in ps._EMBEDDING_GLOBAL_RANKS))
    ...
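The src-rank assertion in test_get_tensor_model_parallel_src_rank above mirrors the usual Megatron arithmetic: the source rank of a tensor-parallel group is its first global rank, (global_rank // tp_world_size) * tp_world_size, which is always 0 when tensor_model_parallel_size equals world_size. A quick illustrative check of that formula (plain Python, not part of the commit):

    # With 8 ranks and a hypothetical tensor_model_parallel_size of 2, ranks pair up
    # as (0,1), (2,3), (4,5), (6,7); the src rank is the even member of each pair.
    tp = 2
    for global_rank in range(8):
        src = (global_rank // tp) * tp
        assert src == global_rank - (global_rank % tp)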
    @@ -114,20 +158,7 @@ def test_is_pipeline_stage_at_split():
            (ps.is_pipeline_stage_before_split(rank) and ps.is_pipeline_stage_after_split(rank+1))
        )

    def test_virtual_pipeline_model_parallel_rank():
        ps.set_virtual_pipeline_model_parallel_rank(rank)
        assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)

    def test_virtual_pipeline_model_parallel_rank():
        ps.set_virtual_pipeline_model_parallel_rank(rank)
        assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)

    def test_get_virtual_pipeline_model_parallel_world_size():
        assert(ps.get_virtual_pipeline_model_parallel_world_size() == virtual_pipeline_model_parallel_size)

    def test_get_tensor_model_parallel_src_rank():
        assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))

    def global_memory_buffer():
        ps._set_global_memory_buffer()
        assert(ps.get_global_memory_buffer() is not None)

    def test_destroy_model_parallel():
        ps.destroy_model_parallel()
        assert(ps._MODEL_PARALLEL_GROUP is None)
    """