Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
8b94a160
Commit
8b94a160
authored
Oct 13, 2022
by
Shanmugam Ramasamy
Committed by
Jared Casper
Oct 13, 2022
Browse files
Adding proper test cases
parent
f8614670
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
330 additions
and
108 deletions
+330
-108
.coverage
.coverage
+0
-0
.coveragerc
.coveragerc
+4
-1
.gitignore
.gitignore
+1
-0
.gitlab-ci.yml
.gitlab-ci.yml
+1
-2
megatron/core/tensor_parallel/random.py
megatron/core/tensor_parallel/random.py
+2
-0
tests/__init__.py
tests/__init__.py
+0
-0
tests/tensor_parallel/test_cross_entropy.py
tests/tensor_parallel/test_cross_entropy.py
+14
-0
tests/tensor_parallel/test_data.py
tests/tensor_parallel/test_data.py
+21
-0
tests/tensor_parallel/test_mappings.py
tests/tensor_parallel/test_mappings.py
+135
-0
tests/tensor_parallel/test_random.py
tests/tensor_parallel/test_random.py
+44
-0
tests/tensor_parallel/test_tensor_parallel_utils.py
tests/tensor_parallel/test_tensor_parallel_utils.py
+36
-0
tests/test_parallel_state.py
tests/test_parallel_state.py
+42
-105
tests/test_utilities.py
tests/test_utilities.py
+30
-0
No files found.
.coverage
deleted
100644 → 0
View file @
f8614670
File deleted
.coveragerc
View file @
8b94a160
[html]
directory = coverage
\ No newline at end of file
directory = coverage
[run]
data_file = .coverage_$LOCAL_RANK
.gitignore
View file @
8b94a160
__pycache__
*.so
build
.coverage_*
*.egg-info
.gitlab-ci.yml
View file @
8b94a160
...
...
@@ -4,8 +4,7 @@ test:
tags
:
-
docker_gpu_enabled
script
:
-
nvidia-smi
-
torchrun --nproc_per_node=2 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
-
torchrun --nproc_per_node=8 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
coverage
:
'
/(?i)total.*?
(100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
artifacts
:
paths
:
...
...
megatron/core/tensor_parallel/random.py
View file @
8b94a160
...
...
@@ -22,6 +22,8 @@ from .utils import (
gather_split_1d_tensor
,
)
from
megatron.core.utils
import
safely_set_viewless_tensor_data
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME
=
'model-parallel-rng'
...
...
tests/__init__.py
0 → 100644
View file @
8b94a160
tests/tensor_parallel/test_cross_entropy.py
0 → 100644
View file @
8b94a160
from
megatron.core.tensor_parallel.cross_entropy
import
vocab_parallel_cross_entropy
import
torch
from
tests.test_utilities
import
Utils
import
numpy
as
np
def test_vocab_parallel_cross_entropy():
    """vocab_parallel_cross_entropy matches precomputed losses on a 4x2 grid."""
    Utils.initialize_model_parallel(4, 2)
    # torch.range is deprecated and end-inclusive; torch.arange(0., 8.) yields
    # the identical float tensor [0., 1., ..., 7.].
    vocab_parallel_logits = torch.arange(0., 8.).repeat(16, 4).cuda()
    target = torch.arange(0, 32, 2).cuda()
    output = vocab_parallel_cross_entropy(vocab_parallel_logits, target)
    # Reference losses repeat every 4 targets because the logits row repeats.
    expected_output = torch.tensor([10.2309, 8.2309, 6.2309, 4.2309] * 4).cuda()
    # Compare rounded values: the reference is only quoted to 4 decimals.
    assert(torch.equal(torch.round(expected_output), torch.round(output)))
    Utils.destroy_model_parallel()
\ No newline at end of file
tests/tensor_parallel/test_data.py
0 → 100644
View file @
8b94a160
from
megatron.core.tensor_parallel.data
import
broadcast_data
import
torch
from
tests.test_utilities
import
Utils
def test_broadcast_data():
    """broadcast_data returns the source rank's tensors unchanged for the
    requested keys."""
    Utils.initialize_model_parallel(2, 4)
    # Keys 0..7 each map to an (8, 8) tensor filled with the key value.
    input_data = {i: torch.ones((8, 8)).cuda() * float(i) for i in range(8)}
    dtype = torch.float32
    actual_output = broadcast_data([0, 1], input_data, dtype)
    assert(torch.equal(actual_output[0], input_data[0]))
    assert(torch.equal(actual_output[1], input_data[1]))
    Utils.destroy_model_parallel()
\ No newline at end of file
tests/tensor_parallel/test_mappings.py
0 → 100644
View file @
8b94a160
from
megatron.core.tensor_parallel
import
mappings
from
tests.test_utilities
import
Utils
import
torch
def test_CopyToModelParallelRegion():
    """Copy is identity in forward/symbolic; backward all-reduces the grad."""
    Utils.initialize_model_parallel(4, 2)
    input_data = torch.ones((1)).cuda() * Utils.rank
    grad = mappings._CopyToModelParallelRegion.backward(None, input_data)
    # Sum of ranks within the tensor-parallel group: 0+1+2+3=6 or 4+5+6+7=22.
    expected = torch.ones(1).cuda()
    if Utils.rank >= 4:
        expected = expected * 22
    else:
        expected = expected * 6
    assert(torch.equal(grad, expected))
    # Forward paths are the identity.
    assert(torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data)))
    assert(torch.equal(input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data)))
    Utils.destroy_model_parallel()
def test_ReduceFromModelParallelRegion():
    """Reduce all-reduces in forward/symbolic; backward is the identity."""
    Utils.initialize_model_parallel(4, 2)
    input_data = torch.ones((1)).cuda() * Utils.rank
    reduced = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data)
    # Sum of ranks within the tensor-parallel group: 0+1+2+3=6 or 4+5+6+7=22.
    expected = torch.ones(1).cuda()
    if Utils.rank >= 4:
        expected = expected * 22
    else:
        expected = expected * 6
    assert(torch.equal(reduced, expected))
    # The public wrapper performs the same all-reduce.
    input_data = torch.ones((1)).cuda() * Utils.rank
    assert(torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), expected))
    # Backward passes the gradient through untouched.
    assert(torch.equal(input_data, mappings._ReduceFromModelParallelRegion.backward(None, input_data)))
    Utils.destroy_model_parallel()
def test_ScatterToModelParallelRegion():
    """Scatter splits the last dim across the group; backward gathers it."""
    Utils.initialize_model_parallel(4, 2)
    input_data = torch.rand((8, 4)).cuda()
    # This rank's column index within its 4-way tensor-parallel group.
    req_dim = int(Utils.rank % (Utils.world_size / 2))
    scattered = mappings.scatter_to_tensor_model_parallel_region(input_data)
    assert(torch.equal(scattered, input_data[:, req_dim].reshape((8, 1))))
    scattered = mappings._ScatterToModelParallelRegion.symbolic(None, input_data)
    assert(torch.equal(scattered, input_data[:, req_dim].reshape((8, 1))))
    # Backward gathers the per-rank gradients along the last dim.
    input_data = torch.ones(8).cuda() * Utils.rank
    actual_output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data)
    expected_output = torch.cat([torch.ones(8) * i for i in range(4)]).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert(torch.equal(actual_output_data, expected_output))
    Utils.destroy_model_parallel()
def test_GatherFromModelParallelRegion():
    """Gather concatenates along the last dim; backward scatters back."""
    Utils.initialize_model_parallel(4, 2)
    input_data = torch.rand((8, 4)).cuda()
    # This rank's column index within its 4-way tensor-parallel group.
    req_dim = int(Utils.rank % (Utils.world_size / 2))
    grad = mappings._GatherFromModelParallelRegion.backward(None, input_data)
    assert(torch.equal(grad, input_data[:, req_dim].reshape((8, 1))))
    input_data = torch.ones(8).cuda() * Utils.rank
    actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data)
    expected_output = torch.cat([torch.ones(8) * i for i in range(4)]).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert(torch.equal(actual_output_data, expected_output))
    assert(torch.equal(
        mappings._GatherFromModelParallelRegion.symbolic(None, input_data),
        expected_output))
    Utils.destroy_model_parallel()
def test_ScatterToSequenceParallelRegion():
    """Scatter splits rows (sequence dim) across the group; backward gathers."""
    Utils.initialize_model_parallel(4, 2)
    input_data = torch.rand((8, 4)).cuda()
    # Each of the 4 tensor-parallel ranks owns 2 of the 8 rows.
    req_dim = int(Utils.rank % (Utils.world_size / 2)) * 2
    output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data)
    assert(torch.equal(output_data, input_data[req_dim:req_dim + 2, :]))
    output_data = mappings.scatter_to_sequence_parallel_region(input_data)
    assert(torch.equal(output_data, input_data[req_dim:req_dim + 2, :]))
    input_data = torch.ones(4).cuda() * Utils.rank
    # Fix: exercise _ScatterToSequenceParallelRegion.backward — the original
    # copy-pasted the model-parallel class; for 1-D input both gathers agree,
    # but this test is about the sequence-parallel mapping.
    output_data = mappings._ScatterToSequenceParallelRegion.backward(None, input_data)
    expected_output = torch.concat([torch.ones(4) * i for i in range(4)]).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert(torch.equal(output_data, expected_output))
    Utils.destroy_model_parallel()
def test_GatherFromSequenceParallelRegion():
    """Gather concatenates along the sequence dim; backward reduce-scatters."""
    Utils.initialize_model_parallel(4, 2)
    input_data = torch.ones(4).cuda() * Utils.rank
    output_data = mappings.gather_from_sequence_parallel_region(input_data)
    expected_output = torch.concat([torch.ones(4) * i for i in range(4)]).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert(torch.equal(output_data, expected_output))
    assert(torch.equal(
        mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data),
        expected_output))
    # Backward needs a ctx carrying tensor_parallel_output_grad.
    input_data = torch.vstack([torch.ones(4) * i for i in range(4)]).cuda()
    class Ctx:
        tensor_parallel_output_grad = True
    output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data)
    # Reduce-scatter of rows 0..3 summed over 4 ranks → 4 * (rank % 4).
    expected_output = torch.ones((1, 4)).cuda() * 4 * int(Utils.rank % 4)
    assert(torch.equal(output_data[0], expected_output))
    Utils.destroy_model_parallel()
def test_ReduceScatterToSequenceParallelRegion():
    """Reduce-scatter sums and splits rows; backward all-gathers them."""
    Utils.initialize_model_parallel(4, 2)
    input_data = torch.vstack([torch.ones(4) * i for i in range(4)]).cuda()
    output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data)
    # Row (rank % 4) summed across the 4 group ranks → 4 * (rank % 4).
    expected_output = torch.ones(4).cuda() * 4 * int(Utils.rank % 4)
    assert(torch.equal(output_data[0], expected_output))
    assert(torch.equal(
        mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data),
        expected_output.reshape((1, 4))))
    input_data = torch.ones(4).cuda() * Utils.rank
    output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(None, input_data)
    expected_output = torch.concat([torch.ones(4) * i for i in range(4)]).cuda()
    if Utils.rank >= 4:
        expected_output = expected_output + 4
    assert(torch.equal(output_data, expected_output))
    Utils.destroy_model_parallel()
tests/tensor_parallel/test_random.py
0 → 100644
View file @
8b94a160
from
megatron.core.tensor_parallel.random
import
CudaRNGStatesTracker
from
megatron.core.tensor_parallel.random
import
model_parallel_cuda_manual_seed
from
megatron.core.tensor_parallel.random
import
_CUDA_RNG_STATE_TRACKER
from
megatron.core.tensor_parallel.random
import
checkpoint
from
tests.test_utilities
import
Utils
import
pytest
import
torch
def test_cuda_rng_states_tracker():
    """CudaRNGStatesTracker: set/get/reset/add/fork behave as documented."""
    rng_tracker = CudaRNGStatesTracker()
    rng_tracker.set_states({"state1": 1234})
    assert(rng_tracker.get_states()["state1"] == 1234)
    rng_tracker.reset()
    assert(rng_tracker.get_states() == {})
    seed = 1111
    rng_tracker.add("state2", seed)
    # Re-using a seed or a name must both be rejected.
    with pytest.raises(Exception):
        assert(rng_tracker.add("state3", seed))
    with pytest.raises(Exception):
        assert(rng_tracker.add("state2", 111))
    assert(rng_tracker.get_states()['state2'] is not None)
    # NOTE(review): assert() on an empty tuple always raises AssertionError,
    # so this pytest.raises block is trivially satisfied — looks like leftover
    # scaffolding; kept to preserve behavior.
    with pytest.raises(Exception):
        assert()
    rng_tracker.fork("state2")
    torch.cuda.manual_seed(seed)
    rng_state = torch.cuda.get_rng_state()
    assert torch.equal(rng_tracker.get_states()['state2'], rng_state)
def test_model_parallel_cuda_manual_seed():
    """Seeding registers a model-parallel RNG state in the global tracker."""
    Utils.initialize_model_parallel(4, 2)
    model_parallel_cuda_manual_seed(0)
    assert(_CUDA_RNG_STATE_TRACKER.get_states()['model-parallel-rng'] is not None)
    Utils.destroy_model_parallel()
def test_checkpoint():
    """checkpoint() runs the wrapped forward and leaves inputs untouched."""
    def add_pair(*inputs):
        # Simple deterministic forward: elementwise sum of the two inputs.
        return inputs[0] + inputs[1]
    # Without RNG-state distribution (None): plain recomputable forward.
    assert(torch.equal(
        torch.ones(16) * 3,
        checkpoint(add_pair, None, torch.ones(16), torch.ones(16) * 2)))
    Utils.initialize_model_parallel()
    input1 = torch.ones((4, 4))
    # With distribute_saved_activations=True the input must survive intact.
    checkpoint(add_pair, True, input1, torch.ones((4, 4)) * 2)
    assert(torch.equal(torch.ones(input1.numel()).cuda(), input1))
    Utils.destroy_model_parallel()
\ No newline at end of file
tests/tensor_parallel/test_tensor_parallel_utils.py
View file @
8b94a160
import
torch
import
megatron.core.tensor_parallel.utils
as
util
import
megatron.core.parallel_state
as
ps
from
tests.test_utilities
import
Utils
# Global rank of this test process, taken from the shared test Utils.
rank = Utils.rank
def test_split_tensor_along_last_dim():
    """Splitting a (3, 4) tensor into 2 along the last dim yields the two
    (3, 2) column halves."""
    input_tensor = torch.rand((3, 4))
    splits = util.split_tensor_along_last_dim(input_tensor, 2)
    # Fix: the original computed torch.equal(...) but never asserted it, and
    # compared against row slices; a last-dim split produces column halves.
    assert torch.equal(input_tensor[:, 0:2], splits[0])
    assert torch.equal(input_tensor[:, 2:], splits[1])
def test_split_tensor_into_1d_equal_chunks():
    """Each tensor-parallel rank receives its half of the flattened tensor."""
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    input_tensor = torch.rand((3, 4))
    output_tensor = util.split_tensor_into_1d_equal_chunks(input_tensor)
    half = int(input_tensor.numel() / 2)
    # Even tensor-parallel ranks hold the first half, odd ranks the second.
    if rank % 2 == 0:
        start, end = 0, half
    else:
        start, end = half, input_tensor.numel()
    assert torch.equal(output_tensor, input_tensor.flatten()[start:end])
    Utils.destroy_model_parallel()
def test_gather_split_1d_tensor():
    """Gathering re-assembles both tensor-parallel halves in rank order."""
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    input_tensor = torch.ones((2, 4)).cuda() * rank
    actual_output_tensor = util.gather_split_1d_tensor(input_tensor)
    flat = input_tensor.flatten()
    # The tensor-parallel partner differs by exactly one global rank.
    if rank % 2 == 0:
        expected_output_tensor = torch.concat((flat, flat + 1))
    else:
        expected_output_tensor = torch.concat((flat - 1, flat))
    assert(torch.equal(actual_output_tensor, expected_output_tensor))
    Utils.destroy_model_parallel()
def test_vocab():
    """VocabUtility partitions the vocab into equal integer index ranges."""
    global_vocab_size = 1600
    # Fix: integer division — vocab ranges are integer index pairs, and float
    # division would make the expected tuple float-typed.
    per_partition_vocab_size = 1600 // Utils.world_size
    expected = (rank * per_partition_vocab_size, (rank + 1) * per_partition_vocab_size)
    assert(expected == util.VocabUtility.vocab_range_from_per_partition_vocab_size(
        global_vocab_size // Utils.world_size, rank, Utils.world_size))
    assert(expected == util.VocabUtility.vocab_range_from_global_vocab_size(
        global_vocab_size, rank, Utils.world_size))
\ No newline at end of file
tests/test_parallel_state.py
View file @
8b94a160
import
os
import
torch
import
megatron.core.parallel_state
as
ps
from
datetime
import
timedelta
import
pytest
from
tests.test_utilities
import
Utils
import
os
# Single source of truth for this process's rank / world size.  The original
# immediately shadowed these with torch.cuda.device_count() / LOCAL_RANK and a
# per-rank debug print — redundant with what Utils already computes.
rank = Utils.rank
world_size = Utils.world_size
def initialize_distributed():
    """Bring up torch.distributed (NCCL) for this process via TCP rendezvous."""
    print(f'Initializing torch.distributed with rank: {rank}, world_size: {world_size}')
    # One GPU per local rank.
    torch.cuda.set_device(rank % torch.cuda.device_count())
    addr = os.getenv('MASTER_ADDR', 'localhost')
    port = os.getenv('MASTER_PORT', '6000')
    init_method = 'tcp://' + addr + ':' + port
    # Short timeout so a mis-configured test run fails fast instead of hanging.
    torch.distributed.init_process_group(
        backend='nccl',
        world_size=world_size,
        rank=rank,
        init_method=init_method,
        timeout=timedelta(seconds=10))
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size=None,
    pipeline_model_parallel_split_rank=None,
):
    """Idempotently (re)initialize megatron model parallelism for a test.

    If a previous test left parallel state initialized, tear it down and
    retry once.  (This might not be the right way to do this.)
    """
    try:
        ps.initialize_model_parallel(
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank)
    except Exception:
        # Fix: narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit) and dropped the dead trailing `pass`.
        ps.destroy_model_parallel()
        ps.initialize_model_parallel(
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank)
def test_initialize__and_destroy_model_parallel():
    """Invalid parallel configs are rejected; a valid 2x4 config creates all
    process groups and destroy() clears the global state."""
    # Before torch.distributed is up, initialization must assert.
    with pytest.raises(AssertionError):
        assert(ps.initialize_model_parallel())
    Utils.initialize_distributed()
    # Oversized / unsupported configurations must raise.
    with pytest.raises(RuntimeError):
        assert(ps.initialize_model_parallel(tensor_model_parallel_size=2 * world_size))
    with pytest.raises(RuntimeError):
        assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2 * world_size))
    with pytest.raises(RuntimeError):
        assert(ps.initialize_model_parallel(
            pipeline_model_parallel_size=world_size,
            tensor_model_parallel_size=world_size))
    with pytest.raises(RuntimeError):
        assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))
    # A valid 2-way tensor x 4-way pipeline layout creates every group.
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    assert(ps.model_parallel_is_initialized())
    assert(ps.get_model_parallel_group() is not None)
    assert(ps.get_tensor_model_parallel_group() is not None)
    assert(ps.get_pipeline_model_parallel_group() is not None)
    assert(ps.get_data_parallel_group() is not None)
    assert(ps.get_embedding_group() is not None)
    assert(ps.get_position_embedding_group() is not None)
    Utils.destroy_model_parallel()
    assert(ps._MODEL_PARALLEL_GROUP is None)
def test_pipeline_parallel_initializations():
    """Pipeline first/next/prev ranks follow the 2x4 (tensor x pipeline) layout."""
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    # With tensor size 2, pipeline groups stride by 2 starting at rank % 2.
    assert(ps.get_pipeline_model_parallel_first_rank() == rank % 2)
    assert(ps.get_data_parallel_src_rank() == rank)
    assert(ps.get_pipeline_model_parallel_next_rank() == ((rank + 2) % world_size))
    assert(ps.get_pipeline_model_parallel_prev_rank() == ((rank - 2) % world_size))
    Utils.destroy_model_parallel()
def test_data_parallel_initializations():
    """With a pure pipeline layout each rank is its own data-parallel group."""
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.get_data_parallel_src_rank() == rank)
    assert(ps.get_data_parallel_world_size() == 1)
    assert(ps.get_data_parallel_rank() == 0)
    Utils.destroy_model_parallel()
def test_tensor_model_parellel_world_size():
    """Tensor-parallel world size is reported, and recomputed when unset."""
    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
    assert(ps.get_tensor_model_parallel_world_size() == world_size)
    # Clearing the cached override falls back to querying the group.
    ps.set_tensor_model_parallel_world_size(None)
    assert(ps.get_tensor_model_parallel_world_size() == world_size)
    Utils.destroy_model_parallel()
def test_pipeline_model_parallel_world_size():
    """Pipeline-parallel world size is reported, and recomputed when unset."""
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.get_pipeline_model_parallel_world_size() == world_size)
    # Clearing the cached override falls back to querying the group.
    ps.set_pipeline_model_parallel_world_size(None)
    assert(ps.get_pipeline_model_parallel_world_size() == world_size)
    Utils.destroy_model_parallel()
def test_tensor_model_parallel_rank():
    """Tensor-parallel rank is reported, and recomputed when unset."""
    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
    assert(ps.get_tensor_model_parallel_rank() == rank)
    # Clearing the cached override falls back to querying the group.
    ps.set_tensor_model_parallel_rank(None)
    assert(ps.get_tensor_model_parallel_rank() == rank)
    Utils.destroy_model_parallel()
def test_pipeline_model_parallel_rank():
    """Pipeline-parallel rank is reported, and recomputed when unset."""
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.get_pipeline_model_parallel_rank() == rank)
    # Clearing the cached override falls back to querying the group.
    ps.set_pipeline_model_parallel_rank(None)
    assert(ps.get_pipeline_model_parallel_rank() == rank)
    Utils.destroy_model_parallel()
def test_is_pipeline_first_stage():
    """Only global rank 0 is the first pipeline stage in a pure pipeline layout."""
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.is_pipeline_first_stage(ignore_virtual=True) == (rank == 0))
    assert(ps.is_pipeline_first_stage() == (rank == 0))
    Utils.destroy_model_parallel()
def test_is_pipeline_last_stage():
    """Only the highest global rank is the last stage in a pure pipeline layout."""
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    assert(ps.is_pipeline_last_stage(ignore_virtual=True) == (rank == world_size - 1))
    assert(ps.is_pipeline_last_stage() == (rank == world_size - 1))
    Utils.destroy_model_parallel()
def test_virtual_pipeline_model_parallel_rank():
    """The virtual pipeline rank setter/getter round-trips."""
    Utils.initialize_model_parallel(pipeline_model_parallel_size=world_size)
    ps.set_virtual_pipeline_model_parallel_rank(rank)
    assert(ps.get_virtual_pipeline_model_parallel_rank() == rank)
    Utils.destroy_model_parallel()
def test_get_tensor_model_parallel_src_rank():
    """The tensor-parallel source rank is the start of this rank's group."""
    Utils.initialize_model_parallel(tensor_model_parallel_size=world_size)
    assert(ps.get_tensor_model_parallel_src_rank() == ((rank // world_size) * world_size))
    Utils.destroy_model_parallel()
"""
def test_get_virtual_pipeline_model_parallel_world_size():
initialize_model_parallel(pipeline_model_parallel_size=world_size)
ps.set_virtual_pipeline_model_parallel_rank(world_size)
assert(ps.get_virtual_pipeline_model_parallel_world_size() == world_size)
ps.destroy_model_parallel()
def test_is_rank_in_embedding_group():
assert(ps.is_rank_in_embedding_group(ignore_virtual=True) == (rank in ps._EMBEDDING_GLOBAL_RANKS))
if rank in ps._EMBEDDING_GLOBAL_RANKS:
assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_first_stage())
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
assert(ps.is_rank_in_embedding_group() == ps.is_pipeline_last_stage())
else:
assert(ps.is_rank_in_embedding_group())
def test_is_rank_in_position_embedding_group():
assert(ps.is_rank_in_position_embedding_group() == (rank in ps._POSITION_EMBEDDING_GLOBAL_RANKS))
def test_is_pipeline_stage_before_split():
if world_size == 1:
assert(ps.is_pipeline_stage_before_split())
# TODO: Changes here for more than one world size
assert(ps.is_pipeline_stage_before_split())
def test_is_pipeline_stage_after_split():
if world_size == 1:
assert(ps.is_pipeline_stage_after_split())
# TODO: Changes here for more than one world size
assert(ps.is_pipeline_stage_before_split())
def test_is_pipeline_stage_at_split():
assert(
ps.is_pipeline_stage_at_split() ==
(ps.is_pipeline_stage_before_split(rank) and ps.is_pipeline_stage_after_split(rank+1))
)
def test_destroy_model_parallel():
ps.destroy_model_parallel()
assert(ps._MODEL_PARALLEL_GROUP is None)
"""
\ No newline at end of file
Utils
.
destroy_model_parallel
()
\ No newline at end of file
tests/test_utilities.py
0 → 100644
View file @
8b94a160
import
os
import
torch
import
megatron.core.parallel_state
as
ps
class Utils:
    """Shared helpers for distributed test setup/teardown.

    One process per local GPU; rank comes from torchrun's LOCAL_RANK.
    """

    # world_size/rank are class attributes evaluated once at import time.
    world_size = torch.cuda.device_count()
    rank = int(os.environ['LOCAL_RANK'])

    @staticmethod
    def initialize_distributed():
        """Bring up torch.distributed (NCCL) via TCP rendezvous."""
        print(f'Initializing torch.distributed with rank: {Utils.rank}, world_size: {Utils.world_size}')
        torch.cuda.set_device(Utils.rank % torch.cuda.device_count())
        addr = os.getenv('MASTER_ADDR', 'localhost')
        port = os.getenv('MASTER_PORT', '6000')
        init_method = 'tcp://' + addr + ':' + port
        torch.distributed.init_process_group(
            backend='nccl',
            world_size=Utils.world_size,
            rank=Utils.rank,
            init_method=init_method)

    @staticmethod
    def destroy_model_parallel():
        """Tear down megatron parallel state and re-synchronize all ranks."""
        ps.destroy_model_parallel()
        torch.distributed.barrier()

    @staticmethod
    def initialize_model_parallel(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        virtual_pipeline_model_parallel_size=None,
        pipeline_model_parallel_split_rank=None,
    ):
        """(Re)initialize megatron model parallelism, starting torch.distributed
        on first use."""
        ps.destroy_model_parallel()
        if not torch.distributed.is_initialized():
            Utils.initialize_distributed()
        ps.initialize_model_parallel(
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
            virtual_pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment