OpenDAS / Megatron-LM / Commits / 52a5f2f2

Commit 52a5f2f2 authored Oct 20, 2020 by Deepak Narayanan
Intra-layer MP -> Tensor MP, Inter-layer MP -> Pipeline MP
parent 7abd3e90
Changes 43
Showing 20 changed files with 403 additions and 403 deletions (+403 -403)
megatron/mpu/__init__.py                   +14  -14
megatron/mpu/cross_entropy.py              +8   -8
megatron/mpu/data.py                       +9   -9
megatron/mpu/grads.py                      +6   -6
megatron/mpu/initialize.py                 +116 -116
megatron/mpu/layers.py                     +22  -22
megatron/mpu/mappings.py                   +12  -12
megatron/mpu/random.py                     +15  -15
megatron/mpu/tests/commons.py              +1   -1
megatron/mpu/tests/test_cross_entropy.py   +11  -11
megatron/mpu/tests/test_data.py            +10  -10
megatron/mpu/tests/test_initialize.py      +23  -23
megatron/mpu/tests/test_layers.py          +75  -75
megatron/mpu/tests/test_random.py          +26  -26
megatron/text_generation_utils.py          +19  -19
megatron/tokenizer/tokenizer.py            +1   -1
megatron/training.py                       +19  -19
megatron/utils.py                          +2   -2
pretrain_bert.py                           +7   -7
pretrain_gpt2.py                           +7   -7
megatron/mpu/__init__.py

@@ -28,15 +28,15 @@ from .initialize import get_data_parallel_rank
 from .initialize import get_data_parallel_world_size
 from .initialize import get_embedding_group
 from .initialize import get_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_group
-from .initialize import get_inter_layer_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_rank, set_intra_layer_model_parallel_rank
-from .initialize import get_inter_layer_model_parallel_rank, set_inter_layer_model_parallel_rank
-from .initialize import is_inter_layer_first_stage, is_inter_layer_last_stage
-from .initialize import get_intra_layer_model_parallel_src_rank
-from .initialize import get_inter_layer_model_parallel_src_rank
-from .initialize import get_intra_layer_model_parallel_world_size, set_intra_layer_model_parallel_world_size
-from .initialize import get_inter_layer_model_parallel_world_size, set_inter_layer_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_pipeline_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
+from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
+from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
+from .initialize import get_tensor_model_parallel_src_rank
+from .initialize import get_pipeline_model_parallel_src_rank
+from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
+from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized

@@ -45,15 +45,15 @@ from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
-from .mappings import copy_to_intra_layer_model_parallel_region
-from .mappings import gather_from_intra_layer_model_parallel_region
-from .mappings import reduce_from_intra_layer_model_parallel_region
-from .mappings import scatter_to_intra_layer_model_parallel_region
+from .mappings import copy_to_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import reduce_from_tensor_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
 from .random import checkpoint
 from .random import get_cuda_rng_tracker
 from .random import init_checkpointed_activations_memory_buffer
-from .random import intra_layer_model_parallel_cuda_manual_seed
+from .random import model_parallel_cuda_manual_seed
 from .random import reset_checkpointed_activations_memory_buffer
 from .utils import divide
megatron/mpu/cross_entropy.py

@@ -16,9 +16,9 @@
 import torch

-from .initialize import get_intra_layer_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_rank
-from .initialize import get_intra_layer_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_world_size
 from .utils import VocabUtility

@@ -31,15 +31,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
         torch.distributed.all_reduce(logits_max,
                                      op=torch.distributed.ReduceOp.MAX,
-                                     group=get_intra_layer_model_parallel_group())
+                                     group=get_tensor_model_parallel_group())
         # Subtract the maximum value.
         vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

         # Get the partition's vocab indecies
         get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
         partition_vocab_size = vocab_parallel_logits.size()[-1]
-        rank = get_intra_layer_model_parallel_rank()
-        world_size = get_intra_layer_model_parallel_world_size()
+        rank = get_tensor_model_parallel_rank()
+        world_size = get_tensor_model_parallel_world_size()
         vocab_start_index, vocab_end_index = get_vocab_range(
             partition_vocab_size, rank, world_size)

@@ -62,7 +62,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         # All reduce is needed to get the chunks from other GPUs.
         torch.distributed.all_reduce(predicted_logits,
                                      op=torch.distributed.ReduceOp.SUM,
-                                     group=get_intra_layer_model_parallel_group())
+                                     group=get_tensor_model_parallel_group())

         # Sum of exponential of logits along vocab dimension across all GPUs.
         exp_logits = vocab_parallel_logits

@@ -70,7 +70,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         sum_exp_logits = exp_logits.sum(dim=-1)
         torch.distributed.all_reduce(sum_exp_logits,
                                      op=torch.distributed.ReduceOp.SUM,
-                                     group=get_intra_layer_model_parallel_group())
+                                     group=get_tensor_model_parallel_group())

         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
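The hunks above only swap in the tensor-model-parallel group accessors; the vocab-parallel cross-entropy logic itself (global max-reduce, per-partition vocab range, global sum of exponentials) is untouched. A minimal single-process sketch of that partitioning, with a hypothetical 2-way split of an 8-entry vocabulary and plain tensor ops standing in for the all-reduces:

import torch

def vocab_range(partition_vocab_size, rank, world_size):
    # Mirrors VocabUtility.vocab_range_from_per_partition_vocab_size:
    # each tensor-parallel rank owns one contiguous slice of the vocabulary.
    start = rank * partition_vocab_size
    return start, start + partition_vocab_size

world_size, rank = 2, 1                              # hypothetical 2-way split
full_logits = torch.randn(3, 8)                      # [tokens, full vocab]
partition = full_logits[:, 4:8]                      # this rank's vocab shard
start, end = vocab_range(partition.size(-1), rank, world_size)

# all_reduce(MAX) over get_tensor_model_parallel_group() yields the global max;
# here the full tensor stands in for that reduction.
logits_max = full_logits.max(dim=-1)[0]
shifted = partition - logits_max.unsqueeze(-1)
sum_exp_local = shifted.exp().sum(dim=-1)            # all_reduce(SUM) would add these across ranks
print(start, end, sum_exp_local.shape)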
megatron/mpu/data.py

@@ -15,9 +15,9 @@
 import torch

-from .initialize import get_intra_layer_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_rank
-from .initialize import get_intra_layer_model_parallel_src_rank
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_src_rank


 _MAX_DATA_DIM = 4

@@ -36,7 +36,7 @@ def _build_key_size_numel_dictionaries(keys, data):
     sizes = [0 for _ in range(max_dim) for _ in keys]

     # Pack the sizes on rank zero.
-    if get_intra_layer_model_parallel_rank() == 0:
+    if get_tensor_model_parallel_rank() == 0:
         offset = 0
         for key in keys:
             assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'

@@ -47,8 +47,8 @@ def _build_key_size_numel_dictionaries(keys, data):
     # Move to GPU and broadcast.
     sizes_cuda = torch.cuda.LongTensor(sizes)
-    torch.distributed.broadcast(sizes_cuda, get_intra_layer_model_parallel_src_rank(),
-                                group=get_intra_layer_model_parallel_group())
+    torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(),
+                                group=get_tensor_model_parallel_group())

     # Move back to cpu and unpack.
     sizes_cpu = sizes_cuda.cpu()

@@ -89,7 +89,7 @@ def broadcast_data(keys, data, datatype):
                                                                   data)

     # Pack on rank zero.
-    if get_intra_layer_model_parallel_rank() == 0:
+    if get_tensor_model_parallel_rank() == 0:
         # Check that all keys have the same data type.
         _check_data_types(keys, data, datatype)

         # Flatten the data associated with the keys

@@ -101,8 +101,8 @@ def broadcast_data(keys, data, datatype):
                                    dtype=datatype)

     # Broadcast
-    torch.distributed.broadcast(flatten_data, get_intra_layer_model_parallel_src_rank(),
-                                group=get_intra_layer_model_parallel_group())
+    torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
+                                group=get_tensor_model_parallel_group())

     # Unpack
     output = {}
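After the rename, broadcast_data still packs on tensor-model-parallel rank 0, broadcasts the sizes and then the flattened payload over get_tensor_model_parallel_group(), and unpacks on every rank. A hedged sketch of the calling pattern (the keys, shapes, and random payload are illustrative only; torch.distributed and mpu must already be initialized on CUDA devices):

import torch
from megatron import mpu

keys = ['text', 'labels']                        # hypothetical batch keys
if mpu.get_tensor_model_parallel_rank() == 0:
    # Only the source rank holds real data; it is flattened and broadcast.
    data = {'text': torch.randint(0, 1000, (4, 512)),
            'labels': torch.randint(0, 1000, (4, 512))}
else:
    data = None                                  # the other ranks receive the broadcast
batch = mpu.broadcast_data(keys, data, torch.int64)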
megatron/mpu/grads.py

@@ -28,9 +28,9 @@ try:
 except Exception as e:
     print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')

-from .initialize import is_inter_layer_first_stage
+from .initialize import is_pipeline_first_stage
 from .initialize import get_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_rank
+from .initialize import get_tensor_model_parallel_rank


 def l2_grad_clipper(parameters, max_norm):

@@ -44,9 +44,9 @@ def l2_grad_clipper(parameters, max_norm):
     parameters_with_grads = list(filter(
         lambda p: p.grad is not None, parameters))
     # Filter parameters for norm calculations.
-    mp_rank_is_zero = (get_intra_layer_model_parallel_rank() == 0)
+    mp_rank_is_zero = (get_tensor_model_parallel_rank() == 0)
     parameters_for_norm = list(filter(
-        lambda p: p.intra_layer_model_parallel or mp_rank_is_zero,
-        parameters_with_grads))
+        lambda p: p.tensor_model_parallel or mp_rank_is_zero,
+        parameters_with_grads))
     # Calculate L2 norm.
     norm, _ = multi_tensor_applier(
         amp_C.multi_tensor_l2norm,

@@ -101,7 +101,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
             # Count embedding layer only once (in first stage).
             # Don't count the weights a second time in the last stage.
             if "embedding" not in n or \
-                is_inter_layer_first_stage():
+                is_pipeline_first_stage():
                 filtered_parameters.append(p)
         parameters = filtered_parameters
     else:

@@ -123,7 +123,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
     else:
         total_norm = 0
         for p in parameters:
-            if p.intra_layer_model_parallel or (get_intra_layer_model_parallel_rank() == 0):
+            if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0):
                 param_norm = p.grad.data.norm(norm_type)
                 total_norm += param_norm.item() ** norm_type
         # Sum across all model-parallel GPUs.
megatron/mpu/initialize.py

@@ -22,10 +22,10 @@ from .utils import ensure_divisibility
 # Intra-layer model parallel group that the current rank belongs to.
-_INTRA_LAYER_MODEL_PARALLEL_GROUP = None
+_TENSOR_MODEL_PARALLEL_GROUP = None
 # Inter-layer model parallel group that the current rank belongs to.
-_INTER_LAYER_MODEL_PARALLEL_GROUP = None
-# Model parallel group (both intra- and inter-layer) that the current rank belongs to.
+_PIPELINE_MODEL_PARALLEL_GROUP = None
+# Model parallel group (both intra- and pipeline) that the current rank belongs to.
 _MODEL_PARALLEL_GROUP = None
 # Embedding group.
 _EMBEDDING_GROUP = None

@@ -33,10 +33,10 @@ _EMBEDDING_GROUP = None
 _DATA_PARALLEL_GROUP = None

 # These values enable us to change the mpu sizes on the fly.
-_MPU_INTRA_LAYER_WORLD_SIZE = None
-_MPU_INTER_LAYER_WORLD_SIZE = None
-_MPU_INTRA_LAYER_RANK = None
-_MPU_INTER_LAYER_RANK = None
+_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
+_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+_MPU_TENSOR_MODEL_PARALLEL_RANK = None
+_MPU_PIPELINE_MODEL_PARALLEL_RANK = None


 def is_unitialized():
@@ -44,25 +44,25 @@ def is_unitialized():
     return _DATA_PARALLEL_GROUP is None


-def initialize_model_parallel(intra_layer_model_parallel_size_=1,
-                              inter_layer_model_parallel_size_=1):
+def initialize_model_parallel(tensor_model_parallel_size_=1,
+                              pipeline_model_parallel_size_=1):
     """
     Initialize model data parallel groups.

     Arguments:
-        intra_layer_model_parallel_size: number of GPUs used to parallelize model intra-layer.
-        inter_layer_model_parallel_size: number of GPUs used to parallelize model inter-layer.
+        tensor_model_parallel_size: number of GPUs used to parallelize model tensor.
+        pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline.

     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
-    use 2 GPUs to parallelize the model intra-layer, and 4 GPUs to parallelize
-    the model inter-layer. The present function will
-    create 8 intra-layer model-parallel groups, 4 inter-layer model-parallel groups
+    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
+    the model pipeline. The present function will
+    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
     and 8 data-parallel groups as:
         8 data_parallel groups:
             [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
-        8 intra-layer model-parallel groups:
+        8 tensor model-parallel groups:
             [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
-        4 inter-layer model-parallel groups:
+        4 pipeline model-parallel groups:
             [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
     Note that for efficiency, the caller should make sure adjacent ranks
     are on the same DGX box. For example if we are using 2 DGX-1 boxes
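Illustrative aside (not part of the commit): the group layout spelled out in the docstring above follows from the same arithmetic the function performs below. A standalone sketch for the documented 16-GPU example with tensor size 2 and pipeline size 4, using plain Python lists in place of torch.distributed groups:

world_size = 16
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 4

num_tensor_groups = world_size // tensor_model_parallel_size      # 8
num_pipeline_groups = world_size // pipeline_model_parallel_size  # 4

# Data-parallel groups: same nested loops as initialize_model_parallel.
data_parallel_groups = []
for i in range(pipeline_model_parallel_size):
    start = i * num_pipeline_groups
    end = (i + 1) * num_pipeline_groups
    for j in range(tensor_model_parallel_size):
        data_parallel_groups.append(
            list(range(start + j, end, tensor_model_parallel_size)))

tensor_groups = [list(range(i * tensor_model_parallel_size,
                            (i + 1) * tensor_model_parallel_size))
                 for i in range(num_tensor_groups)]
pipeline_groups = [list(range(i, world_size, num_pipeline_groups))
                   for i in range(num_pipeline_groups)]

print(data_parallel_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], ...]
print(tensor_groups)         # [[0, 1], [2, 3], ..., [14, 15]]
print(pipeline_groups)       # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]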
@@ -70,22 +70,22 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
     ranks 8 to 15 belong to the second box.
     """
     if torch.distributed.get_rank() == 0:
-        print('> initializing intra-layer model parallel with size {}'.format(
-            intra_layer_model_parallel_size_))
-        print('> initializing inter-layer model parallel with size {}'.format(
-            inter_layer_model_parallel_size_))
+        print('> initializing tensor model parallel with size {}'.format(
+            tensor_model_parallel_size_))
+        print('> initializing pipeline model parallel with size {}'.format(
+            pipeline_model_parallel_size_))
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    intra_layer_model_parallel_size = min(intra_layer_model_parallel_size_, world_size)
-    inter_layer_model_parallel_size = min(inter_layer_model_parallel_size_, world_size)
+    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
+    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
     ensure_divisibility(world_size,
-                        intra_layer_model_parallel_size * inter_layer_model_parallel_size)
-    data_parallel_size = world_size // (intra_layer_model_parallel_size *
-                                        inter_layer_model_parallel_size)
+                        tensor_model_parallel_size * pipeline_model_parallel_size)
+    data_parallel_size = world_size // (tensor_model_parallel_size *
+                                        pipeline_model_parallel_size)

-    num_intra_layer_model_parallel_groups = world_size // intra_layer_model_parallel_size
-    num_inter_layer_model_parallel_groups = world_size // inter_layer_model_parallel_size
+    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
+    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
     num_data_parallel_groups = world_size // data_parallel_size

     rank = torch.distributed.get_rank()
@@ -95,12 +95,12 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
     assert _DATA_PARALLEL_GROUP is None, \
         'data parallel group is already initialized'
     all_data_parallel_group_ranks = []
-    for i in range(inter_layer_model_parallel_size):
-        start_rank = i * num_inter_layer_model_parallel_groups
-        end_rank = (i + 1) * num_inter_layer_model_parallel_groups
-        for j in range(intra_layer_model_parallel_size):
+    for i in range(pipeline_model_parallel_size):
+        start_rank = i * num_pipeline_model_parallel_groups
+        end_rank = (i + 1) * num_pipeline_model_parallel_groups
+        for j in range(tensor_model_parallel_size):
             ranks = range(start_rank + j, end_rank,
-                          intra_layer_model_parallel_size)
+                          tensor_model_parallel_size)
             all_data_parallel_group_ranks.append(list(ranks))
             group = torch.distributed.new_group(ranks)
             if rank in ranks:
@@ -117,31 +117,31 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
         if rank in ranks:
             _MODEL_PARALLEL_GROUP = group

-    # Build the intra-layer model-parallel groups.
-    global _INTRA_LAYER_MODEL_PARALLEL_GROUP
-    assert _INTRA_LAYER_MODEL_PARALLEL_GROUP is None, \
-        'intra-layer model parallel group is already initialized'
-    for i in range(num_intra_layer_model_parallel_groups):
-        ranks = range(i * intra_layer_model_parallel_size,
-                      (i + 1) * intra_layer_model_parallel_size)
+    # Build the tensor model-parallel groups.
+    global _TENSOR_MODEL_PARALLEL_GROUP
+    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
+        'tensor model parallel group is already initialized'
+    for i in range(num_tensor_model_parallel_groups):
+        ranks = range(i * tensor_model_parallel_size,
+                      (i + 1) * tensor_model_parallel_size)
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
-            _INTRA_LAYER_MODEL_PARALLEL_GROUP = group
+            _TENSOR_MODEL_PARALLEL_GROUP = group

-    # Build the inter-layer model-parallel groups and embedding groups
-    # (first and last rank in each inter-layer model-parallel group).
-    global _INTER_LAYER_MODEL_PARALLEL_GROUP
-    assert _INTER_LAYER_MODEL_PARALLEL_GROUP is None, \
-        'inter-layer model parallel group is already initialized'
+    # Build the pipeline model-parallel groups and embedding groups
+    # (first and last rank in each pipeline model-parallel group).
+    global _PIPELINE_MODEL_PARALLEL_GROUP
+    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
+        'pipeline model parallel group is already initialized'
     global _EMBEDDING_GROUP
     assert _EMBEDDING_GROUP is None, \
         'embedding group is already initialized'
-    for i in range(num_inter_layer_model_parallel_groups):
+    for i in range(num_pipeline_model_parallel_groups):
         ranks = range(i, world_size,
-                      num_inter_layer_model_parallel_groups)
+                      num_pipeline_model_parallel_groups)
         group = torch.distributed.new_group(ranks)
         if rank in ranks:
-            _INTER_LAYER_MODEL_PARALLEL_GROUP = group
+            _PIPELINE_MODEL_PARALLEL_GROUP = group
         # Setup embedding group (to exchange gradients between
         # first and last stages).
         if len(ranks) > 1:
@@ -155,8 +155,8 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
-    if _INTRA_LAYER_MODEL_PARALLEL_GROUP is None or \
-        _INTER_LAYER_MODEL_PARALLEL_GROUP is None or \
+    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
+        _PIPELINE_MODEL_PARALLEL_GROUP is None or \
         _DATA_PARALLEL_GROUP is None:
         return False
     return True
@@ -169,18 +169,18 @@ def get_model_parallel_group():
     return _MODEL_PARALLEL_GROUP


-def get_intra_layer_model_parallel_group():
-    """Get the intra-layer model parallel group the caller rank belongs to."""
-    assert _INTRA_LAYER_MODEL_PARALLEL_GROUP is not None, \
+def get_tensor_model_parallel_group():
+    """Get the tensor model parallel group the caller rank belongs to."""
+    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
         'intra_layer_model parallel group is not initialized'
-    return _INTRA_LAYER_MODEL_PARALLEL_GROUP
+    return _TENSOR_MODEL_PARALLEL_GROUP


-def get_inter_layer_model_parallel_group():
-    """Get the inter-layer model parallel group the caller rank belongs to."""
-    assert _INTER_LAYER_MODEL_PARALLEL_GROUP is not None, \
-        'inter_layer_model parallel group is not initialized'
-    return _INTER_LAYER_MODEL_PARALLEL_GROUP
+def get_pipeline_model_parallel_group():
+    """Get the pipeline model parallel group the caller rank belongs to."""
+    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
+        'pipeline_model parallel group is not initialized'
+    return _PIPELINE_MODEL_PARALLEL_GROUP


 def get_data_parallel_group():
@@ -197,87 +197,87 @@ def get_embedding_group():
     return _EMBEDDING_GROUP


-def set_intra_layer_model_parallel_world_size(world_size):
-    """Set the intra-layer model parallel size"""
-    global _MPU_INTRA_LAYER_WORLD_SIZE
-    _MPU_INTRA_LAYER_WORLD_SIZE = world_size
+def set_tensor_model_parallel_world_size(world_size):
+    """Set the tensor model parallel size"""
+    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size


-def set_inter_layer_model_parallel_world_size(world_size):
-    """Set the inter-layer model parallel size"""
-    global _MPU_INTER_LAYER_WORLD_SIZE
-    _MPU_INTER_LAYER_WORLD_SIZE = world_size
+def set_pipeline_model_parallel_world_size(world_size):
+    """Set the pipeline model parallel size"""
+    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


-def get_intra_layer_model_parallel_world_size():
-    """Return world size for the intra-layer model parallel group."""
-    global _MPU_INTRA_LAYER_WORLD_SIZE
-    if _MPU_INTRA_LAYER_WORLD_SIZE is not None:
-        return _MPU_INTRA_LAYER_WORLD_SIZE
-    return torch.distributed.get_world_size(group=get_intra_layer_model_parallel_group())
+def get_tensor_model_parallel_world_size():
+    """Return world size for the tensor model parallel group."""
+    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+    if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
+        return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
+    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())


-def get_inter_layer_model_parallel_world_size():
-    """Return world size for the inter-layer model parallel group."""
-    global _MPU_INTER_LAYER_WORLD_SIZE
-    if _MPU_INTER_LAYER_WORLD_SIZE is not None:
-        return _MPU_INTER_LAYER_WORLD_SIZE
-    return torch.distributed.get_world_size(group=get_inter_layer_model_parallel_group())
+def get_pipeline_model_parallel_world_size():
+    """Return world size for the pipeline model parallel group."""
+    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
+        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())


-def set_intra_layer_model_parallel_rank(rank):
-    """Set intra-layer model parallel rank."""
-    global _MPU_INTRA_LAYER_RANK
-    _MPU_INTRA_LAYER_RANK = rank
+def set_tensor_model_parallel_rank(rank):
+    """Set tensor model parallel rank."""
+    global _MPU_TENSOR_MODEL_PARALLEL_RANK
+    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank


-def set_inter_layer_model_parallel_rank(rank):
-    """Set inter-layer model parallel rank."""
-    global _MPU_INTER_LAYER_RANK
-    _MPU_INTER_LAYER_RANK = rank
+def set_pipeline_model_parallel_rank(rank):
+    """Set pipeline model parallel rank."""
+    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
+    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank


-def get_intra_layer_model_parallel_rank():
-    """Return my rank for the intra-layer model parallel group."""
-    global _MPU_INTRA_LAYER_RANK
-    if _MPU_INTRA_LAYER_RANK is not None:
-        return _MPU_INTRA_LAYER_RANK
-    return torch.distributed.get_rank(group=get_intra_layer_model_parallel_group())
+def get_tensor_model_parallel_rank():
+    """Return my rank for the tensor model parallel group."""
+    global _MPU_TENSOR_MODEL_PARALLEL_RANK
+    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
+        return _MPU_TENSOR_MODEL_PARALLEL_RANK
+    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())


-def get_inter_layer_model_parallel_rank():
-    """Return my rank for the inter-layer model parallel group."""
-    global _MPU_INTER_LAYER_RANK
-    if _MPU_INTER_LAYER_RANK is not None:
-        return _MPU_INTER_LAYER_RANK
-    return torch.distributed.get_rank(group=get_inter_layer_model_parallel_group())
+def get_pipeline_model_parallel_rank():
+    """Return my rank for the pipeline model parallel group."""
+    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
+    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
+        return _MPU_PIPELINE_MODEL_PARALLEL_RANK
+    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


-def is_inter_layer_first_stage():
-    """Return True if in the first inter-layer model-parallel stage, False otherwise."""
-    return get_inter_layer_model_parallel_rank() == 0
+def is_pipeline_first_stage():
+    """Return True if in the first pipeline model-parallel stage, False otherwise."""
+    return get_pipeline_model_parallel_rank() == 0


-def is_inter_layer_last_stage():
-    """Return True if in the last inter-layer model-parallel stage, False otherwise."""
-    return get_inter_layer_model_parallel_rank() == (
-        get_inter_layer_model_parallel_world_size() - 1)
+def is_pipeline_last_stage():
+    """Return True if in the last pipeline model-parallel stage, False otherwise."""
+    return get_pipeline_model_parallel_rank() == (
+        get_pipeline_model_parallel_world_size() - 1)


-def get_intra_layer_model_parallel_src_rank():
+def get_tensor_model_parallel_src_rank():
     """Calculate the global rank corresponding to a local rank
-    in the intra-layer model parallel group."""
+    in the tensor model parallel group."""
     global_rank = torch.distributed.get_rank()
-    local_world_size = get_intra_layer_model_parallel_world_size()
+    local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size


-def get_inter_layer_model_parallel_src_rank():
+def get_pipeline_model_parallel_src_rank():
     """Calculate the global rank corresponding to a local rank
-    in the inter-layer model parallel group."""
+    in the pipeline model parallel group."""
     global_rank = torch.distributed.get_rank()
     global_world_size = torch.distributed.get_world_size()
-    local_world_size = get_inter_layer_model_parallel_world_size()
+    local_world_size = get_pipeline_model_parallel_world_size()
     return global_rank % (global_world_size // local_world_size)
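For the same 16-GPU layout, the two renamed src-rank helpers reduce to simple integer arithmetic; a quick standalone check of the formulas above (global ranks plugged in by hand, not taken from the commit):

# get_tensor_model_parallel_src_rank: first rank of the caller's tensor group.
def tensor_src_rank(global_rank, tensor_world_size):
    return (global_rank // tensor_world_size) * tensor_world_size

# get_pipeline_model_parallel_src_rank: first rank of the caller's pipeline group.
def pipeline_src_rank(global_rank, global_world_size, pipeline_world_size):
    return global_rank % (global_world_size // pipeline_world_size)

assert tensor_src_rank(7, 2) == 6        # g7 sits in tensor group [g6, g7]
assert pipeline_src_rank(9, 16, 4) == 1  # g9 sits in pipeline group [g1, g5, g9, g13]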
@@ -293,9 +293,9 @@ def get_data_parallel_rank():
 def destroy_model_parallel():
     """Set the groups to none."""
-    global _INTRA_LAYER_MODEL_PARALLEL_GROUP
-    _INTRA_LAYER_MODEL_PARALLEL_GROUP = None
-    global _INTER_LAYER_MODEL_PARALLEL_GROUP
-    _INTER_LAYER_MODEL_PARALLEL_GROUP = None
+    global _TENSOR_MODEL_PARALLEL_GROUP
+    _TENSOR_MODEL_PARALLEL_GROUP = None
+    global _PIPELINE_MODEL_PARALLEL_GROUP
+    _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
megatron/mpu/layers.py

@@ -35,12 +35,12 @@ except Exception as e:
           'instead of apex.normalization.FusedLayerNorm!')
     from torch.nn import LayerNorm

-from .initialize import get_intra_layer_model_parallel_rank
-from .initialize import get_intra_layer_model_parallel_world_size
-from .mappings import copy_to_intra_layer_model_parallel_region
-from .mappings import gather_from_intra_layer_model_parallel_region
-from .mappings import reduce_from_intra_layer_model_parallel_region
-from .mappings import scatter_to_intra_layer_model_parallel_region
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_world_size
+from .mappings import copy_to_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import reduce_from_tensor_model_parallel_region
+from .mappings import scatter_to_tensor_model_parallel_region
 from .random import get_cuda_rng_tracker
 from .utils import divide
 from .utils import split_tensor_along_last_dim

@@ -51,7 +51,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
                                   partition_dim, stride=1):
     """Initialize affine weight for model parallel on GPU."""
-    weight.intra_layer_model_parallel = True
+    weight.tensor_model_parallel = True
     weight.partition_dim = partition_dim
     weight.partition_stride = stride

@@ -68,7 +68,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
     Build the master weight on all processes and scatter
     the relevant chunk."""

-    weight.intra_layer_model_parallel = True
+    weight.tensor_model_parallel = True
     weight.partition_dim = partition_dim
     weight.partition_stride = stride

@@ -85,7 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
     weight_list = torch.split(master_weight, per_partition_per_stride_size,
                               dim=partition_dim)
     rank = get_model_parallel_rank()
-    world_size = get_intra_layer_model_parallel_world_size()
+    world_size = get_tensor_model_parallel_world_size()
     my_weight_list = weight_list[rank::world_size]

     with torch.no_grad():

@@ -119,12 +119,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.scale_grad_by_freq = False
         self.sparse = False
         self._weight = None
-        self.intra_layer_model_parallel_size = get_intra_layer_model_parallel_world_size()
+        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
         # Divide the weight matrix along the vocaburaly dimension.
         self.vocab_start_index, self.vocab_end_index = \
             VocabUtility.vocab_range_from_global_vocab_size(
-                self.num_embeddings, get_intra_layer_model_parallel_rank(),
-                self.intra_layer_model_parallel_size)
+                self.num_embeddings, get_tensor_model_parallel_rank(),
+                self.tensor_model_parallel_size)
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index

@@ -145,7 +145,7 @@ class VocabParallelEmbedding(torch.nn.Module):
                 partition_dim=0, stride=1)

     def forward(self, input_):
-        if self.intra_layer_model_parallel_size > 1:
+        if self.tensor_model_parallel_size > 1:
             # Build the mask.
             input_mask = (input_ < self.vocab_start_index) | \
                          (input_ >= self.vocab_end_index)

@@ -160,10 +160,10 @@ class VocabParallelEmbedding(torch.nn.Module):
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
         # Mask the output embedding.
-        if self.intra_layer_model_parallel_size > 1:
+        if self.tensor_model_parallel_size > 1:
             output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
-        output = reduce_from_intra_layer_model_parallel_region(output_parallel)
+        output = reduce_from_tensor_model_parallel_region(output_parallel)
         return output
@@ -202,7 +202,7 @@ class ColumnParallelLinear(torch.nn.Module):
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
-        world_size = get_intra_layer_model_parallel_world_size()
+        world_size = get_tensor_model_parallel_world_size()
         self.output_size_per_partition = divide(output_size, world_size)
         self.skip_bias_add = skip_bias_add

@@ -235,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
                     self.output_size_per_partition,
                     device=torch.cuda.current_device(),
                     dtype=args.params_dtype))
-                self.bias.intra_layer_model_parallel = True
+                self.bias.tensor_model_parallel = True
                 self.bias.partition_dim = 0
                 self.bias.stride = stride
                 # Always initialize bias to zero.

@@ -248,14 +248,14 @@ class ColumnParallelLinear(torch.nn.Module):
     def forward(self, input_):
         # Set up backprop all-reduce.
-        input_parallel = copy_to_intra_layer_model_parallel_region(input_)
+        input_parallel = copy_to_tensor_model_parallel_region(input_)
         # Matrix multiply.

         bias = self.bias if not self.skip_bias_add else None
         output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_from_intra_layer_model_parallel_region(output_parallel)
+            output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None

@@ -304,7 +304,7 @@ class RowParallelLinear(torch.nn.Module):
         self.output_size = output_size
         self.input_is_parallel = input_is_parallel
         # Divide the weight matrix along the last dimension.
-        world_size = get_intra_layer_model_parallel_world_size()
+        world_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, world_size)
         self.skip_bias_add = skip_bias_add

@@ -348,11 +348,11 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            input_parallel = scatter_to_intra_layer_model_parallel_region(input_)
+            input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
         output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
-        output_ = reduce_from_intra_layer_model_parallel_region(output_parallel)
+        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
megatron/mpu/mappings.py

@@ -15,7 +15,7 @@
 import torch

-from .initialize import get_intra_layer_model_parallel_group, get_intra_layer_model_parallel_world_size, get_intra_layer_model_parallel_rank
+from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
 from .utils import split_tensor_along_last_dim

@@ -23,11 +23,11 @@ def _reduce(input_):
     """All-reduce the the input tensor across model parallel group."""

     # Bypass the function if we are using only 1 GPU.
-    if get_intra_layer_model_parallel_world_size() == 1:
+    if get_tensor_model_parallel_world_size() == 1:
         return input_

     # All-reduce.
-    torch.distributed.all_reduce(input_, group=get_intra_layer_model_parallel_group())
+    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

     return input_

@@ -36,7 +36,7 @@ def _split(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""

-    world_size = get_intra_layer_model_parallel_world_size()
+    world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
         return input_

@@ -45,7 +45,7 @@ def _split(input_):
     input_list = split_tensor_along_last_dim(input_, world_size)

     # Note: torch.split does not create contiguous tensors by default.
-    rank = get_intra_layer_model_parallel_rank()
+    rank = get_tensor_model_parallel_rank()
     output = input_list[rank].contiguous()

     return output

@@ -54,18 +54,18 @@ def _split(input_):
 def _gather(input_):
     """Gather tensors and concatinate along the last dimension."""

-    world_size = get_intra_layer_model_parallel_world_size()
+    world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
         return input_

     # Size and dimension.
     last_dim = input_.dim() - 1
-    rank = get_intra_layer_model_parallel_rank()
+    rank = get_tensor_model_parallel_rank()

     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
-    torch.distributed.all_gather(tensor_list, input_, group=get_intra_layer_model_parallel_group())
+    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())

     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=last_dim).contiguous()

@@ -141,17 +141,17 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
 # Helper functions.
 # -----------------

-def copy_to_intra_layer_model_parallel_region(input_):
+def copy_to_tensor_model_parallel_region(input_):
     return _CopyToModelParallelRegion.apply(input_)


-def reduce_from_intra_layer_model_parallel_region(input_):
+def reduce_from_tensor_model_parallel_region(input_):
     return _ReduceFromModelParallelRegion.apply(input_)


-def scatter_to_intra_layer_model_parallel_region(input_):
+def scatter_to_tensor_model_parallel_region(input_):
     return _ScatterToModelParallelRegion.apply(input_)


-def gather_from_intra_layer_model_parallel_region(input_):
+def gather_from_tensor_model_parallel_region(input_):
     return _GatherFromModelParallelRegion.apply(input_)
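The renamed helpers wrap the same autograd regions as before; the _split and _gather shown above are inverses along the last dimension. A minimal single-rank sketch of that round trip with plain torch (no process group), assuming the last dimension divides evenly:

import torch

world_size, rank = 4, 2                      # hypothetical tensor-parallel layout
x = torch.arange(24.0).view(2, 12)           # last dim divisible by world_size

# _split keeps this rank's contiguous slice of the last dimension ...
chunks = torch.split(x, x.size(-1) // world_size, dim=-1)
local = chunks[rank].contiguous()            # shape (2, 3)

# ... and _gather concatenates the per-rank slices back together.
restored = torch.cat(chunks, dim=-1)
assert torch.equal(restored, x)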
megatron/mpu/random.py

@@ -28,13 +28,13 @@ from megatron import get_args
 from megatron.memory import allocate_mem_buff

 from .initialize import get_data_parallel_rank
-from .initialize import get_intra_layer_model_parallel_group
-from .initialize import get_intra_layer_model_parallel_rank
-from .initialize import get_intra_layer_model_parallel_world_size
+from .initialize import get_tensor_model_parallel_group
+from .initialize import get_tensor_model_parallel_rank
+from .initialize import get_tensor_model_parallel_world_size


 # Default name for the model parallel rng tracker.
-_MODEL_PARALLEL_RNG_TRACKER_NAME = 'intra-layer-model-parallel-rng'
+_MODEL_PARALLEL_RNG_TRACKER_NAME = 'tensor-model-parallel-rng'

 # Whether apply model parallelsim to checkpointed hidden states.

@@ -104,15 +104,15 @@ def _set_cuda_rng_state(new_state, device=-1):
 def split_tensor_into_1d_equal_chunks(tensor):
     """Break a tensor into equal 1D chunks."""
     data = tensor.view(-1)
-    partition_size = torch.numel(data) // get_intra_layer_model_parallel_world_size()
-    start_index = partition_size * get_intra_layer_model_parallel_rank()
+    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
+    start_index = partition_size * get_tensor_model_parallel_rank()
     end_index = start_index + partition_size
     return data[start_index:end_index]


 def gather_split_1d_tensor(tensor):
     """Opposite of above function, gather values from model parallel ranks."""
-    world_size = get_intra_layer_model_parallel_world_size()
+    world_size = get_tensor_model_parallel_world_size()
     numel = torch.numel(tensor)
     numel_gathered = world_size * numel
     gathered = torch.empty(numel_gathered, dtype=tensor.dtype,

@@ -120,7 +120,7 @@ def gather_split_1d_tensor(tensor):
                            requires_grad=False)
     chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
     torch.distributed.all_gather(chunks, tensor,
-                                 group=get_intra_layer_model_parallel_group())
+                                 group=get_tensor_model_parallel_group())
     return gathered

@@ -204,7 +204,7 @@ def get_cuda_rng_tracker():
     return _CUDA_RNG_STATE_TRACKER


-def intra_layer_model_parallel_cuda_manual_seed(seed):
+def model_parallel_cuda_manual_seed(seed):
     """Initialize model parallel cuda seed.

     This function should be called after the model parallel is

@@ -215,15 +215,15 @@ def intra_layer_model_parallel_cuda_manual_seed(seed):
         default state: This is for data parallelism and is the same among a
                        set of model parallel GPUs but different across
                        different model paralle groups. This is used for
-                       example for dropout in the non-intra-layer-model-parallel regions.
-        intra-layer-model-parallel state: This state is different among a set of model
+                       example for dropout in the non-tensor-model-parallel regions.
+        tensor-model-parallel state: This state is different among a set of model
                       parallel GPUs, but the same across data parallel
                       groups. This is used for example for dropout in
                       model parallel regions.
     """
     # 2718 is just for fun and any POSITIVE value will work.
     offset = seed + 2718
-    intra_layer_model_parallel_seed = offset + get_intra_layer_model_parallel_rank()
+    tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
     # Data parallel gets the original sedd.
     data_parallel_seed = seed

@@ -231,15 +231,15 @@ def intra_layer_model_parallel_cuda_manual_seed(seed):
         print('> initializing model parallel cuda seeds on global rank {}, '
               'model parallel rank {}, and data parallel rank {} with '
               'model parallel seed: {} and data parallel seed: {}'.format(
-                  torch.distributed.get_rank(), get_intra_layer_model_parallel_rank(),
-                  get_data_parallel_rank(), intra_layer_model_parallel_seed,
+                  torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
+                  get_data_parallel_rank(), tensor_model_parallel_seed,
                   data_parallel_seed), flush=True)
     _CUDA_RNG_STATE_TRACKER.reset()
     # Set the default state.
     torch.cuda.manual_seed(data_parallel_seed)
     # and model parallel state.
     _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
-                                intra_layer_model_parallel_seed)
+                                tensor_model_parallel_seed)


 class CheckpointFunction(torch.autograd.Function):
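The renamed seed bookkeeping keeps two RNG streams per GPU: the default (data-parallel) state seeded with seed, and a tensor-model-parallel state seeded with seed + 2718 + tensor rank. A small sketch of the derivation (the 2718 offset comes from the code above; the ranks are hypothetical):

seed = 1234
offset = seed + 2718                        # constant taken from the code above

# Two GPUs in the same tensor-model-parallel group:
for tensor_rank in (0, 1):
    tensor_model_parallel_seed = offset + tensor_rank
    data_parallel_seed = seed               # identical on both ranks
    print(tensor_rank, tensor_model_parallel_seed, data_parallel_seed)
# -> 0 3952 1234
# -> 1 3953 1234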
megatron/mpu/tests/commons.py

@@ -36,7 +36,7 @@ def set_random_seed(seed):
     random.seed(seed)
     numpy.random.seed(seed)
     torch.manual_seed(seed)
-    mpu.intra_layer_model_parallel_cuda_manual_seed(seed)
+    mpu.model_parallel_cuda_manual_seed(seed)


 def initialize_distributed(backend='nccl'):
megatron/mpu/tests/test_cross_entropy.py

@@ -47,7 +47,7 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
     identity = IdentityLayer((batch_size, seq_length, vocab_size),
                              scale=logits_scale).cuda()
     logits = identity()
-    logits_parallel = mpu.scatter_to_intra_layer_model_parallel_region(logits)
+    logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits)
     target = torch.cuda.LongTensor(
         size=(batch_size, seq_length)).random_(0, vocab_size)
     loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()

@@ -55,20 +55,20 @@
     return loss, identity.weight.grad


-def test_cross_entropy(intra_layer_model_parallel_size):
+def test_cross_entropy(tensor_model_parallel_size):

     if torch.distributed.get_rank() == 0:
         print('> testing cross entropy with model parallel size {} ...'.
-              format(intra_layer_model_parallel_size))
+              format(tensor_model_parallel_size))

-    mpu.initialize_model_parallel(intra_layer_model_parallel_size)
-    intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
+    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

     batch_size = 13
     seq_length = 17
     vocab_size_per_partition = 11
     logits_scale = 1000.0
-    vocab_size = vocab_size_per_partition * intra_layer_model_parallel_size
+    vocab_size = vocab_size_per_partition * tensor_model_parallel_size
     seed = 1234

     loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,

@@ -89,7 +89,7 @@ def test_cross_entropy(intra_layer_model_parallel_size):
     assert error < 1.0e-6

     # Reset groups
-    mpu.destroy_intra_layer_model_parallel()
+    mpu.destroy_tensor_model_parallel()

     torch.distributed.barrier()
     if torch.distributed.get_rank() == 0:

@@ -101,8 +101,8 @@ if __name__ == '__main__':
     initialize_distributed()
     world_size = torch.distributed.get_world_size()

-    intra_layer_model_parallel_size = 1
-    while intra_layer_model_parallel_size <= world_size:
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
         print_separator('test cross entropy')
-        test_cross_entropy(intra_layer_model_parallel_size)
-        intra_layer_model_parallel_size *= 2
+        test_cross_entropy(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
megatron/mpu/tests/test_data.py

@@ -24,15 +24,15 @@ import sys
 sys.path.append("../..")


-def test_broadcast_data(intra_layer_model_parallel_size):
+def test_broadcast_data(tensor_model_parallel_size):

     if torch.distributed.get_rank() == 0:
         print('> testing broadcast_data with model parallel size {} ...'.
-              format(intra_layer_model_parallel_size))
+              format(tensor_model_parallel_size))

-    mpu.initialize_model_parallel(intra_layer_model_parallel_size)
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
     torch.manual_seed(1234 + mpu.get_data_parallel_rank())
-    intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
+    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

     key_size_t = {'key1': [7, 11],
                   'key2': [8, 2, 1],

@@ -48,7 +48,7 @@ def test_broadcast_data(intra_layer_model_parallel_size):
         data_t[key] = data[key].clone()
     data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
     data_t['keyX'] = data['keyX'].clone()
-    if mpu.get_intra_layer_model_parallel_rank() != 0:
+    if mpu.get_tensor_model_parallel_rank() != 0:
         data = None

     data_utils._check_data_types(keys, data_t, torch.int64)

@@ -69,7 +69,7 @@
         assert data_b[key].sub(tensor).abs().max() == 0

     # Reset groups
-    mpu.destroy_intra_layer_model_parallel()
+    mpu.destroy_tensor_model_parallel()

     torch.distributed.barrier()
     if torch.distributed.get_rank() == 0:

@@ -81,8 +81,8 @@ if __name__ == '__main__':
     initialize_distributed()
     world_size = torch.distributed.get_world_size()

-    intra_layer_model_parallel_size = 1
-    while intra_layer_model_parallel_size <= world_size:
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
         print_separator('test test broadcast data')
-        test_broadcast_data(intra_layer_model_parallel_size)
-        intra_layer_model_parallel_size *= 2
+        test_broadcast_data(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
megatron/mpu/tests/test_initialize.py

@@ -21,15 +21,15 @@ import sys
 sys.path.append("../..")


-def test_initialize_model_parallel(intra_layer_model_parallel_size):
+def test_initialize_model_parallel(tensor_model_parallel_size):

     if torch.distributed.get_rank() == 0:
         print('> testing initialize_model_parallel with size {} ...'.format(
-            intra_layer_model_parallel_size))
-    intra_layer_model_parallel_size_ = min(intra_layer_model_parallel_size,
+            tensor_model_parallel_size))
+    tensor_model_parallel_size_ = min(tensor_model_parallel_size,
                                       torch.distributed.get_world_size())
     assert not mpu.model_parallel_is_initialized()
-    mpu.initialize_model_parallel(intra_layer_model_parallel_size_)
+    mpu.initialize_model_parallel(tensor_model_parallel_size_)
     assert mpu.model_parallel_is_initialized()

     # Checks.

@@ -38,15 +38,15 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size):
     assert rank == torch.distributed.get_rank(group=group)

     # Model parallel.
-    world_size = intra_layer_model_parallel_size_
-    rank = torch.distributed.get_rank() % intra_layer_model_parallel_size_
-    assert world_size == mpu.get_intra_layer_model_parallel_world_size()
-    assert rank == mpu.get_intra_layer_model_parallel_rank()
-    check(mpu.get_intra_layer_model_parallel_group(), world_size, rank)
+    world_size = tensor_model_parallel_size_
+    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
+    assert world_size == mpu.get_tensor_model_parallel_world_size()
+    assert rank == mpu.get_tensor_model_parallel_rank()
+    check(mpu.get_tensor_model_parallel_group(), world_size, rank)

     # Data parallel.
-    world_size = torch.distributed.get_world_size() // intra_layer_model_parallel_size_
-    rank = torch.distributed.get_rank() // intra_layer_model_parallel_size
+    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
+    rank = torch.distributed.get_rank() // tensor_model_parallel_size
     assert world_size == mpu.get_data_parallel_world_size()
     assert rank == mpu.get_data_parallel_rank()
     check(mpu.get_data_parallel_group(), world_size, rank)

@@ -59,20 +59,20 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size):
         print('>> passed the test :-)')


-def test_get_intra_layer_model_parallel_src_rank(intra_layer_model_parallel_size_):
+def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):

     if torch.distributed.get_rank() == 0:
-        print('> testing get_intra_layer_model_parallel_src_rank with size {} ...'.format(
-            intra_layer_model_parallel_size_))
-    intra_layer_model_parallel_size = min(intra_layer_model_parallel_size_,
+        print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
+            tensor_model_parallel_size_))
+    tensor_model_parallel_size = min(tensor_model_parallel_size_,
                                       torch.distributed.get_world_size())
     assert not mpu.model_parallel_is_initialized()
-    mpu.initialize_model_parallel(intra_layer_model_parallel_size)
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
     assert mpu.model_parallel_is_initialized()

     # Checks
-    src_rank = torch.distributed.get_rank() - mpu.get_intra_layer_model_parallel_rank()
-    assert mpu.get_intra_layer_model_parallel_src_rank() == src_rank
+    src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
+    assert mpu.get_tensor_model_parallel_src_rank() == src_rank

     # Reset groups
     mpu.destroy_model_parallel()

@@ -86,10 +86,10 @@ if __name__ == '__main__':
     initialize_distributed()
     world_size = torch.distributed.get_world_size()

-    intra_layer_model_parallel_size = 1
-    while intra_layer_model_parallel_size <= world_size:
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
         print_separator('test initialize model parallel')
-        test_initialize_model_parallel(intra_layer_model_parallel_size)
+        test_initialize_model_parallel(tensor_model_parallel_size)
         print_separator('test model parallel source rank')
-        test_get_intra_layer_model_parallel_src_rank(intra_layer_model_parallel_size)
-        intra_layer_model_parallel_size *= 2
+        test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
megatron/mpu/tests/test_layers.py

@@ -26,14 +26,14 @@ import sys
 sys.path.append("../..")


-def test_parallel_embedding(intra_layer_model_parallel_size):
+def test_parallel_embedding(tensor_model_parallel_size):

     if torch.distributed.get_rank() == 0:
         print('> testing parallel embedding with model parallel size {} ...'.
-              format(intra_layer_model_parallel_size))
+              format(tensor_model_parallel_size))

-    mpu.initialize_model_parallel(intra_layer_model_parallel_size)
-    intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
+    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

     batch_size = 17
     seq_length = 23
@@ -80,16 +80,16 @@ def test_parallel_embedding(intra_layer_model_parallel_size):
     assert error < 1.0e-12, 'error: {}'.format(error)

     weight_grad_orig = torch.split(embedding_original.weight.grad,
-                                   hidden_size // intra_layer_model_parallel_size,
-                                   1)[mpu.get_intra_layer_model_parallel_rank()]
+                                   hidden_size // tensor_model_parallel_size,
+                                   1)[mpu.get_tensor_model_parallel_rank()]
     error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
     print('   error in grad (parallel) on global rank {}: {}'.format(
         torch.distributed.get_rank(), error))
     assert error < 1.0e-12, 'error: {}'.format(error)

     weight_grad_orig = torch.split(embedding_original.weight.grad,
-                                   vocab_size // intra_layer_model_parallel_size,
-                                   0)[mpu.get_intra_layer_model_parallel_rank()]
+                                   vocab_size // tensor_model_parallel_size,
+                                   0)[mpu.get_tensor_model_parallel_rank()]
     error = embedding_vocab_parallel.weight.grad.sub(weight_grad_orig).abs().max()
     print('   error in grad (vocab parallel) on global rank {}: {}'.format(
print
(
'>> passed the test :-)'
)
def
test_initialize_affine_weight
(
intra_laye
r_model_parallel_size
):
def
test_initialize_affine_weight
(
tenso
r_model_parallel_size
):
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing initialize_affine_weight with model parallel '
'size: {}'
.
format
(
intra_laye
r_model_parallel_size
))
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
'size: {}'
.
format
(
tenso
r_model_parallel_size
))
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
seed
=
12345
input_size_coeff
=
13
input_size
=
input_size_coeff
*
intra_laye
r_model_parallel_size
input_size
=
input_size_coeff
*
tenso
r_model_parallel_size
output_size_coeff
=
17
output_size
=
output_size_coeff
*
intra_laye
r_model_parallel_size
output_size
=
output_size_coeff
*
tenso
r_model_parallel_size
# ---------------
# Column parallel
...
...
@@ -131,7 +131,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size):
set_random_seed
(
seed
)
master_weight
=
torch
.
empty
(
output_size
,
input_size
)
torch
.
nn
.
init
.
normal_
(
master_weight
)
rank
=
mpu
.
get_
intra_laye
r_model_parallel_rank
()
rank
=
mpu
.
get_
tenso
r_model_parallel_rank
()
my_weight
=
torch
.
split
(
master_weight
,
output_size_coeff
,
dim
=
0
)[
rank
].
contiguous
().
clone
()
...
...
@@ -154,7 +154,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size):
set_random_seed
(
seed
)
master_weight
=
torch
.
empty
(
output_size
,
input_size
)
torch
.
nn
.
init
.
normal_
(
master_weight
)
rank
=
mpu
.
get_
intra_laye
r_model_parallel_rank
()
rank
=
mpu
.
get_
tenso
r_model_parallel_rank
()
my_weight
=
torch
.
split
(
master_weight
,
input_size_coeff
,
dim
=
1
)[
rank
].
contiguous
().
clone
()
...
...
@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module):
         return self.weight


-def test_column_parallel_linear(intra_layer_model_parallel_size):
+def test_column_parallel_linear(tensor_model_parallel_size):

-    mpu.initialize_model_parallel(intra_layer_model_parallel_size)
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
     if torch.distributed.get_rank() == 0:
         print('> testing ColumnParallelLinear with model parallel '
-              'size: {}'.format(intra_layer_model_parallel_size))
-    intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
+              'size: {}'.format(tensor_model_parallel_size))
+    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

     seed = 12345
     set_random_seed(seed)
     input_size_coeff = 13
-    input_size = input_size_coeff * intra_layer_model_parallel_size
+    input_size = input_size_coeff * tensor_model_parallel_size
     output_size_coeff = 17
-    output_size = output_size_coeff * intra_layer_model_parallel_size
+    output_size = output_size_coeff * tensor_model_parallel_size
     batch_size = 7

     # Network

@@ -219,7 +219,7 @@ def test_column_parallel_linear(intra_layer_model_parallel_size):
     dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
     dLdX = torch.matmul(dLdY, A)

-    rank = mpu.get_intra_layer_model_parallel_rank()
+    rank = mpu.get_tensor_model_parallel_rank()
     my_dLdA = torch.split(dLdA, output_size_coeff,
                           dim=0)[rank].contiguous().clone()
     error = my_dLdA.sub(linear_layer.weight.grad).abs().max()

@@ -250,20 +250,20 @@ def test_column_parallel_linear(intra_layer_model_parallel_size):
         print(' >> passed the test :-)')


-def test_row_parallel_linear(intra_layer_model_parallel_size):
+def test_row_parallel_linear(tensor_model_parallel_size):

-    mpu.initialize_model_parallel(intra_layer_model_parallel_size)
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
     if torch.distributed.get_rank() == 0:
         print('> testing RowParallelLinear with model parallel '
-              'size: {}'.format(intra_layer_model_parallel_size))
-    intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
+              'size: {}'.format(tensor_model_parallel_size))
+    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

     seed = 12345
     set_random_seed(seed)
     input_size_coeff = 13
-    input_size = input_size_coeff * intra_layer_model_parallel_size
+    input_size = input_size_coeff * tensor_model_parallel_size
     output_size_coeff = 17
-    output_size = output_size_coeff * intra_layer_model_parallel_size
+    output_size = output_size_coeff * tensor_model_parallel_size
     batch_size = 7

     # Network
@@ -286,7 +286,7 @@ def test_row_parallel_linear(intra_layer_model_parallel_size):
dLdb
=
torch
.
matmul
(
torch
.
ones
(
batch_size
,
1
).
cuda
().
t
(),
dLdY
).
view
(
-
1
)
dLdX
=
torch
.
matmul
(
dLdY
,
A
)
rank
=
mpu
.
get_
intra_laye
r_model_parallel_rank
()
rank
=
mpu
.
get_
tenso
r_model_parallel_rank
()
my_dLdA
=
torch
.
split
(
dLdA
,
input_size_coeff
,
dim
=
1
)[
rank
].
contiguous
().
clone
()
error
=
my_dLdA
.
sub
(
linear_layer
.
weight
.
grad
).
abs
().
max
()
...
...
@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module):
return
self
.
weight
def
parallel_self_attention
(
intra_laye
r_model_parallel_size
,
num_att_heads_per_partition
,
def
parallel_self_attention
(
tenso
r_model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
):
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
seed
=
12345
set_random_seed
(
seed
)
...
...
@@ -352,17 +352,17 @@ def parallel_self_attention(intra_layer_model_parallel_size, num_att_heads_per_p
# Backward
loss
.
backward
()
rank
=
mpu
.
get_
intra_laye
r_model_parallel_rank
()
rank
=
mpu
.
get_
tenso
r_model_parallel_rank
()
mpu
.
destroy_model_parallel
()
return
rank
,
hidden_size
,
intra_laye
r_model_parallel_size
,
loss
,
\
return
rank
,
hidden_size
,
tenso
r_model_parallel_size
,
loss
,
\
attention_layer
,
identity_layer
def
test_parallel_self_attention
(
intra_laye
r_model_parallel_size
):
def
test_parallel_self_attention
(
tenso
r_model_parallel_size
):
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing ParallelSelfAttention with model parallel '
'size: {}'
.
format
(
intra_laye
r_model_parallel_size
))
'size: {}'
.
format
(
tenso
r_model_parallel_size
))
num_att_heads_per_partition
=
3
hidden_size_per_att_head
=
7
...
...
@@ -370,14 +370,14 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
batch_size
=
5
sequence_length
=
13
rank_1
,
hideen_size_1
,
intra_laye
r_model_parallel_size_1
,
loss_1
,
\
rank_1
,
hideen_size_1
,
tenso
r_model_parallel_size_1
,
loss_1
,
\
attention_layer_1
,
identity_layer_1
=
parallel_self_attention
(
1
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
)
rank
,
hidden_size
,
intra_laye
r_model_parallel_size
,
loss
,
\
rank
,
hidden_size
,
tenso
r_model_parallel_size
,
loss
,
\
attention_layer
,
identity_layer
=
parallel_self_attention
(
intra_laye
r_model_parallel_size
,
num_att_heads_per_partition
,
tenso
r_model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
)
assert
hideen_size_1
==
hidden_size
...
...
@@ -389,7 +389,7 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
my_lin_grad_list
=
torch
.
split
(
attention_layer_1
.
query_key_value
.
weight
.
grad
,
hidden_size
//
intra_laye
r_model_parallel_size
,
0
)[
rank
::
intra_laye
r_model_parallel_size
]
hidden_size
//
tenso
r_model_parallel_size
,
0
)[
rank
::
tenso
r_model_parallel_size
]
my_lin_grad
=
torch
.
cat
(
my_lin_grad_list
,
dim
=
0
)
error
=
my_lin_grad
.
sub
(
attention_layer
.
query_key_value
.
weight
.
grad
).
abs
().
max
()
...
...
@@ -410,11 +410,11 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
print
(
' >> passed the test :-)'
)
def
parallel_transformer
(
intra_laye
r_model_parallel_size
,
num_att_heads_per_partition
,
def
parallel_transformer
(
tenso
r_model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
batch_size
,
sequence_length
):
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
seed
=
12345
set_random_seed
(
seed
)
...
...
@@ -440,31 +440,31 @@ def parallel_transformer(intra_layer_model_parallel_size, num_att_heads_per_part
# Backward
loss
.
backward
()
rank
=
mpu
.
get_
intra_laye
r_model_parallel_rank
()
rank
=
mpu
.
get_
tenso
r_model_parallel_rank
()
mpu
.
destroy_model_parallel
()
return
rank
,
hidden_size
,
intra_laye
r_model_parallel_size
,
loss
,
\
return
rank
,
hidden_size
,
tenso
r_model_parallel_size
,
loss
,
\
transformer_layer
,
identity_layer
def
test_parallel_transformer_layer
(
intra_laye
r_model_parallel_size
):
def
test_parallel_transformer_layer
(
tenso
r_model_parallel_size
):
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing ParallelTransformerLayer with model parallel '
'size: {}'
.
format
(
intra_laye
r_model_parallel_size
))
'size: {}'
.
format
(
tenso
r_model_parallel_size
))
num_att_heads_per_partition
=
3
hidden_size_per_att_head
=
7
batch_size
=
5
sequence_length
=
13
rank_1
,
hidden_size_1
,
intra_laye
r_model_parallel_size_1
,
loss_1
,
\
rank_1
,
hidden_size_1
,
tenso
r_model_parallel_size_1
,
loss_1
,
\
transformer_layer_1
,
identity_layer_1
=
parallel_transformer
(
1
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
batch_size
,
sequence_length
)
rank
,
hidden_size
,
intra_laye
r_model_parallel_size
,
loss
,
\
rank
,
hidden_size
,
tenso
r_model_parallel_size
,
loss
,
\
transformer_layer
,
identity_layer
=
parallel_transformer
(
intra_laye
r_model_parallel_size
,
num_att_heads_per_partition
,
tenso
r_model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
batch_size
,
sequence_length
)
error
=
loss_1
.
sub
(
loss
).
abs
().
max
()
...
...
@@ -494,37 +494,37 @@ if __name__ == '__main__':
world_size
=
torch
.
distributed
.
get_world_size
()
print_separator
(
'test initialize affine weight'
)
intra_laye
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
test_initialize_affine_weight
(
intra_laye
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
=
1
while
tenso
r_model_parallel_size
<=
world_size
:
test_initialize_affine_weight
(
tenso
r_model_parallel_size
)
tenso
r_model_parallel_size
*=
2
intra_laye
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
tenso
r_model_parallel_size
=
1
while
tenso
r_model_parallel_size
<=
world_size
:
print_separator
(
'test parallel embedding'
)
test_parallel_embedding
(
intra_laye
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
test_parallel_embedding
(
tenso
r_model_parallel_size
)
tenso
r_model_parallel_size
*=
2
print_separator
(
'test column-parallel linear'
)
intra_laye
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
test_column_parallel_linear
(
intra_laye
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
=
1
while
tenso
r_model_parallel_size
<=
world_size
:
test_column_parallel_linear
(
tenso
r_model_parallel_size
)
tenso
r_model_parallel_size
*=
2
print_separator
(
'test row-parallel linear'
)
intra_laye
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
test_row_parallel_linear
(
intra_laye
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
=
1
while
tenso
r_model_parallel_size
<=
world_size
:
test_row_parallel_linear
(
tenso
r_model_parallel_size
)
tenso
r_model_parallel_size
*=
2
print_separator
(
'test parallel self-attention'
)
intra_laye
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
test_parallel_self_attention
(
intra_laye
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
=
1
while
tenso
r_model_parallel_size
<=
world_size
:
test_parallel_self_attention
(
tenso
r_model_parallel_size
)
tenso
r_model_parallel_size
*=
2
print_separator
(
'test parallel transformer'
)
intra_laye
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
test_parallel_transformer_layer
(
intra_laye
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
=
1
while
tenso
r_model_parallel_size
<=
world_size
:
test_parallel_transformer_layer
(
tenso
r_model_parallel_size
)
tenso
r_model_parallel_size
*=
2
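Note (not part of this diff): the renamed tests keep the same sharding convention they always checked — ColumnParallelLinear shards the master weight along dim 0 in blocks of output_size_coeff, RowParallelLinear along dim 1 in blocks of input_size_coeff, indexed by the tensor-model-parallel rank. A minimal single-process sketch of that convention, assuming an illustrative parallel size of 4:

    # Single-process illustration only; no distributed setup required.
    import torch

    tensor_model_parallel_size = 4          # assumed world size for illustration
    output_size_coeff, input_size_coeff = 17, 13
    output_size = output_size_coeff * tensor_model_parallel_size
    input_size = input_size_coeff * tensor_model_parallel_size

    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)

    # ColumnParallelLinear: each rank owns a contiguous block of output rows.
    column_shards = torch.split(master_weight, output_size_coeff, dim=0)
    # RowParallelLinear: each rank owns a contiguous block of input columns.
    row_shards = torch.split(master_weight, input_size_coeff, dim=1)

    for rank in range(tensor_model_parallel_size):
        assert column_shards[rank].shape == (output_size_coeff, input_size)
        assert row_shards[rank].shape == (output_size, input_size_coeff)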
megatron/mpu/tests/test_random.py  View file @ 52a5f2f2

@@ -21,14 +21,14 @@ import sys
 sys.path.append("../..")


-def test_set_cuda_rng_state(intra_layer_model_parallel_size):
+def test_set_cuda_rng_state(tensor_model_parallel_size):

     if torch.distributed.get_rank() == 0:
         print('> testing set_rng_state with size {} ...'.
-              format(intra_layer_model_parallel_size))
+              format(tensor_model_parallel_size))

-    mpu.initialize_model_parallel(intra_layer_model_parallel_size)
-    intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
+    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

     size = 123
     seed = 1234

@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(intra_layer_model_parallel_size):
         print('>> passed the test :-)')


-def test_cuda_rng_tracker(intra_layer_model_parallel_size):
+def test_cuda_rng_tracker(tensor_model_parallel_size):

     if torch.distributed.get_rank() == 0:
         print('> testing cuda rng tracker with size {} ...'.
-              format(intra_layer_model_parallel_size))
+              format(tensor_model_parallel_size))

-    mpu.initialize_model_parallel(intra_layer_model_parallel_size)
-    intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
+    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

     seed_1 = 1234
     seed_2 = 4321

@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(intra_layer_model_parallel_size):
         print('>> passed the test :-)')


-def test_intra_layer_model_parallel_cuda_manual_seed(intra_layer_model_parallel_size):
+def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):

     if torch.distributed.get_rank() == 0:
         print('> testing model parallel cuda manual seed with size {} ...'.
-              format(intra_layer_model_parallel_size))
+              format(tensor_model_parallel_size))

-    mpu.initialize_model_parallel(intra_layer_model_parallel_size)
-    intra_layer_model_parallel_size = mpu.get_intra_layer_model_parallel_world_size()
+    mpu.initialize_model_parallel(tensor_model_parallel_size)
+    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

-    mpu.intra_layer_model_parallel_cuda_manual_seed(12345)
+    mpu.model_parallel_cuda_manual_seed(12345)
     assert torch.cuda.initial_seed() == 12345
     with mpu.get_cuda_rng_tracker().fork():
         assert torch.cuda.initial_seed() == (12345 + 2718 +
-                                             mpu.get_intra_layer_model_parallel_rank())
+                                             mpu.get_tensor_model_parallel_rank())

     # Reset the tracker
     mpu.get_cuda_rng_tracker().reset()

@@ -185,20 +185,20 @@ if __name__ == '__main__':
     initialize_distributed()
     world_size = torch.distributed.get_world_size()

-    intra_layer_model_parallel_size = 1
-    while intra_layer_model_parallel_size <= world_size:
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
         print_separator('test set rng state')
-        test_set_cuda_rng_state(intra_layer_model_parallel_size)
-        intra_layer_model_parallel_size *= 2
+        test_set_cuda_rng_state(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2

-    intra_layer_model_parallel_size = 1
-    while intra_layer_model_parallel_size <= world_size:
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
         print_separator('test cuda rng tracker')
-        test_cuda_rng_tracker(intra_layer_model_parallel_size)
-        intra_layer_model_parallel_size *= 2
+        test_cuda_rng_tracker(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2

-    intra_layer_model_parallel_size = 1
-    while intra_layer_model_parallel_size <= world_size:
+    tensor_model_parallel_size = 1
+    while tensor_model_parallel_size <= world_size:
         print_separator('test model parallel cuda manual seed')
-        test_intra_layer_model_parallel_cuda_manual_seed(intra_layer_model_parallel_size)
-        intra_layer_model_parallel_size *= 2
+        test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
+        tensor_model_parallel_size *= 2
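Note (not part of this diff): as the assertion in test_model_parallel_cuda_manual_seed shows, model_parallel_cuda_manual_seed keeps the default CUDA seed at the requested value, while the forked model-parallel RNG tracker gets that seed plus 2718 plus the tensor-model-parallel rank. A small sketch of that arithmetic (the helper function below is illustrative, not an mpu API):

    def model_parallel_seeds(base_seed, tensor_model_parallel_size):
        """Return (default seed, per-rank tracker seeds) for each tensor rank."""
        offset = base_seed + 2718
        return base_seed, [offset + rank for rank in range(tensor_model_parallel_size)]

    default_seed, tracker_seeds = model_parallel_seeds(12345, 4)
    print(default_seed)    # 12345 on every rank (data-parallel RNG stays identical)
    print(tracker_seeds)   # [15063, 15064, 15065, 15066] -> one per tensor rank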
megatron/text_generation_utils.py  View file @ 52a5f2f2

@@ -88,7 +88,7 @@ def generate_samples_input_from_file(model):
     # Read the sample file and open the output file.
     assert args.sample_input_file is not None, \
         'sample input file is not provided.'
-    if mpu.get_intra_layer_model_parallel_rank() == 0:
+    if mpu.get_tensor_model_parallel_rank() == 0:
         fname = open(args.sample_input_file, "r")
         all_raw_text = fname.readlines()
         input_count = len(all_raw_text)

@@ -105,10 +105,10 @@ def generate_samples_input_from_file(model):
     model.eval()
     with torch.no_grad():
         while True:
-            torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
+            torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             terminate_runs = 0

-            if mpu.get_intra_layer_model_parallel_rank() == 0:
+            if mpu.get_tensor_model_parallel_rank() == 0:
                 raw_text = all_raw_text[input_pos]
                 input_pos += 1
                 if input_pos == input_count:

@@ -131,8 +131,8 @@ def generate_samples_input_from_file(model):
            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
-                                       mpu.get_intra_layer_model_parallel_src_rank(),
-                                       group=mpu.get_intra_layer_model_parallel_group())
+                                       mpu.get_tensor_model_parallel_src_rank(),
+                                       group=mpu.get_tensor_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:

@@ -143,7 +143,7 @@ def generate_samples_input_from_file(model):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()

-               if mpu.get_intra_layer_model_parallel_rank() == 0:
+               if mpu.get_tensor_model_parallel_rank() == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.detokenize(

@@ -158,7 +158,7 @@ def generate_samples_input_from_file(model):
            raw_text = None

-           torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
+           torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
            context_count += 1

@@ -171,10 +171,10 @@ def generate_samples_interactive(model, print_frequency=24):
     model.eval()
     with torch.no_grad():
         while True:
-            torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
+            torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
             terminate_runs = 0

-            if mpu.get_intra_layer_model_parallel_rank() == 0:
+            if mpu.get_tensor_model_parallel_rank() == 0:
                 os.system('clear')
                 raw_text = input("\nContext prompt (stop to exit) >>> ")
                 while not raw_text:

@@ -198,8 +198,8 @@ def generate_samples_interactive(model, print_frequency=24):
            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
-                                       mpu.get_intra_layer_model_parallel_src_rank(),
-                                       group=mpu.get_intra_layer_model_parallel_group())
+                                       mpu.get_tensor_model_parallel_src_rank(),
+                                       group=mpu.get_tensor_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:

@@ -210,7 +210,7 @@ def generate_samples_interactive(model, print_frequency=24):
                    decode_tokens, _ = decode_tokens
                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()

-               if mpu.get_intra_layer_model_parallel_rank() == 0 and \
+               if mpu.get_tensor_model_parallel_rank() == 0 and \
                        counter % print_frequency == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)

@@ -218,7 +218,7 @@ def generate_samples_interactive(model, print_frequency=24):
                        decode_tokens)[len(raw_text):]
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

-               if mpu.get_intra_layer_model_parallel_rank() == 0:
+               if mpu.get_tensor_model_parallel_rank() == 0:
                    os.system('clear')
                    print("\nContext:", raw_text, flush=True)
                    trim_decode_tokens = tokenizer.detokenize(

@@ -226,10 +226,10 @@ def generate_samples_interactive(model, print_frequency=24):
                    print("\nMegatron-LM:", trim_decode_tokens, flush=True)

            raw_text = None

-           torch.distributed.barrier(group=mpu.get_intra_layer_model_parallel_group())
+           torch.distributed.barrier(group=mpu.get_tensor_model_parallel_group())
            context_count += 1
-           if mpu.get_intra_layer_model_parallel_rank() == 0:
+           if mpu.get_tensor_model_parallel_rank() == 0:
                input("\nPress any key to continue >>>")

@@ -299,11 +299,11 @@ def get_token_stream(model, context_tokens):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)

     torch.distributed.broadcast(context_length_tensor,
-                                mpu.get_intra_layer_model_parallel_src_rank(),
-                                group=mpu.get_intra_layer_model_parallel_group())
+                                mpu.get_tensor_model_parallel_src_rank(),
+                                group=mpu.get_tensor_model_parallel_group())
     torch.distributed.broadcast(context_tokens_tensor,
-                                mpu.get_intra_layer_model_parallel_src_rank(),
-                                group=mpu.get_intra_layer_model_parallel_group())
+                                mpu.get_tensor_model_parallel_src_rank(),
+                                group=mpu.get_tensor_model_parallel_group())

     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
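Note (not part of this diff): the recurring pattern here is "rank 0 of the tensor-model-parallel group decides or reads input, then broadcasts to the rest of the group". In Megatron the source rank and group come from mpu.get_tensor_model_parallel_src_rank() and mpu.get_tensor_model_parallel_group(); the sketch below collapses that to a single-process gloo run so it executes standalone:

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    terminate_runs = 0
    terminate_runs_tensor = torch.LongTensor([terminate_runs])
    # Source rank of the group decides; every other rank receives the value.
    dist.broadcast(terminate_runs_tensor, src=0)
    terminate_runs = terminate_runs_tensor[0].item()
    print(terminate_runs)   # 0

    dist.destroy_process_group()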
megatron/tokenizer/tokenizer.py  View file @ 52a5f2f2

@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
     after = orig_vocab_size
     multiple = args.make_vocab_size_divisible_by * \
-        args.intra_layer_model_parallel_size
+        args.tensor_model_parallel_size
     while (after % multiple) != 0:
         after += 1
     if args.rank == 0:
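Note (not part of this diff): the padding rule above rounds the vocabulary up until it is divisible by make_vocab_size_divisible_by * tensor_model_parallel_size, so the embedding table splits evenly across tensor-parallel ranks. A standalone sketch, using illustrative values of 128 and 8 for the two factors:

    def padded_vocab_size(orig_vocab_size, make_vocab_size_divisible_by=128,
                          tensor_model_parallel_size=8):
        multiple = make_vocab_size_divisible_by * tensor_model_parallel_size
        after = orig_vocab_size
        while (after % multiple) != 0:
            after += 1
        return after

    print(padded_vocab_size(50257))   # 51200 for a GPT-2-sized vocab with the values above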
megatron/training.py  View file @ 52a5f2f2

@@ -124,10 +124,10 @@ def get_model(model_provider_func):
     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
-        print(' > number of parameters on (intra-layer, inter-layer) '
+        print(' > number of parameters on (tensor, pipeline) '
               'model parallel rank ({}, {}): {}'.format(
-            mpu.get_intra_layer_model_parallel_rank(),
-            mpu.get_inter_layer_model_parallel_rank(),
+            mpu.get_tensor_model_parallel_rank(),
+            mpu.get_pipeline_model_parallel_rank(),
             sum([p.nelement() for p in model.parameters()])), flush=True)

     # GPU allocation.

@@ -166,8 +166,8 @@ def get_optimizer(model):
     # Add model parallel attribute if it is not set.
     for param_group in param_groups:
         for param in param_group['params']:
-            if not hasattr(param, 'intra_layer_model_parallel'):
-                param.intra_layer_model_parallel = False
+            if not hasattr(param, 'tensor_model_parallel'):
+                param.tensor_model_parallel = False

     # Use Adam.
     optimizer = Adam(param_groups,
                      lr=args.lr,
                      weight_decay=args.weight_decay,

@@ -260,7 +260,7 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
         tensor_recv_prev=tensor_recv_prev,
         tensor_send_next=tensor_send_next,
         tensor_recv_next=tensor_recv_next,
-        group=mpu.get_inter_layer_model_parallel_group())
+        group=mpu.get_pipeline_model_parallel_group())

     return tensor_recv_prev, tensor_recv_next

@@ -304,7 +304,7 @@ def train_step(forward_step_func, data_iterator,
     optimizer.zero_grad()

     # Compute number of microbatches in a minibatch.
-    num_microbatches_to_pipeline = args.inter_layer_model_parallel_size \
+    num_microbatches_to_pipeline = args.pipeline_model_parallel_size \
         if args.use_pipelining else 1

     input_tensors = []

@@ -313,7 +313,7 @@ def train_step(forward_step_func, data_iterator,
     # Run forward pass for all microbatches in minibatch.
     for i in range(num_microbatches_to_pipeline):
-        if not mpu.is_inter_layer_first_stage():
+        if not mpu.is_pipeline_first_stage():
             input_tensor, _ = communicate(
                 tensor_send_next=None,
                 tensor_send_prev=None,

@@ -327,7 +327,7 @@ def train_step(forward_step_func, data_iterator,
         output_tensor = forward_step_func(data_iterator, model, input_tensor)
         timers('forward').stop()

-        if mpu.is_inter_layer_last_stage():
+        if mpu.is_pipeline_last_stage():
             loss, loss_reduced = output_tensor
             output_tensor = loss
             losses_reduced.append(loss_reduced)

@@ -346,7 +346,7 @@ def train_step(forward_step_func, data_iterator,
         input_tensor = input_tensors.pop(0)
         output_tensor = output_tensors.pop(0)

-        if mpu.is_inter_layer_last_stage():
+        if mpu.is_pipeline_last_stage():
             output_grad_tensor = None
         else:
             _, output_grad_tensor = communicate(

@@ -362,7 +362,7 @@ def train_step(forward_step_func, data_iterator,
             backward_step(optimizer, model, input_tensor, output_tensor, output_grad_tensor)
         timers('backward').stop()

-        if not mpu.is_inter_layer_first_stage():
+        if not mpu.is_pipeline_first_stage():
             communicate(
                 tensor_send_next=None,
                 tensor_send_prev=input_grad_tensor,

@@ -383,8 +383,8 @@ def train_step(forward_step_func, data_iterator,
     timers('backward-master-grad').stop()

     # All-reduce across first and last stages.
-    if (mpu.is_inter_layer_first_stage() or mpu.is_inter_layer_last_stage()) and \
-            args.inter_layer_model_parallel_size > 1:
+    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
+            args.pipeline_model_parallel_size > 1:
         unwrapped_model = model
         while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
             unwrapped_model = unwrapped_model.module

@@ -421,7 +421,7 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1

-    if mpu.is_inter_layer_last_stage():
+    if mpu.is_pipeline_last_stage():
         # Average loss across microbatches.
         loss_reduced = {}
         for key in losses_reduced[0]:

@@ -604,7 +604,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                 print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters))

-            if not mpu.is_inter_layer_first_stage():
+            if not mpu.is_pipeline_first_stage():
                 input_tensor, _ = communicate(
                     tensor_send_next=None,
                     tensor_send_prev=None,

@@ -616,7 +616,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
             # Forward evaluation.
             output_tensor = forward_step_func(data_iterator, model, input_tensor)

-            if mpu.is_inter_layer_last_stage():
+            if mpu.is_pipeline_last_stage():
                 _, loss_dict = output_tensor
                 # Reduce across processes.
                 for key in loss_dict:

@@ -671,7 +671,7 @@ def build_train_valid_test_data_iterators(
     print_rank_0('> building train, validation, and test datasets ...')

     # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_intra_layer_model_parallel_rank() == 0:
+    if mpu.get_tensor_model_parallel_rank() == 0:
         # Rank, size, and global batch size.
         data_parallel_size = mpu.get_data_parallel_world_size()
         global_batch_size = args.batch_size * data_parallel_size

@@ -709,8 +709,8 @@ def build_train_valid_test_data_iterators(
     # Broadcast num tokens.
     torch.distributed.broadcast(flags,
-                                mpu.get_intra_layer_model_parallel_src_rank(),
-                                group=mpu.get_intra_layer_model_parallel_group())
+                                mpu.get_tensor_model_parallel_src_rank(),
+                                group=mpu.get_tensor_model_parallel_group())
     args.do_train = flags[0].item()
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()
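Note (not part of this diff): train_step and evaluate gate every send/receive and loss computation on the renamed stage predicates. A plain-Python sketch of that gating, assuming is_pipeline_first_stage() and is_pipeline_last_stage() identify the two ends of the pipeline-model-parallel group (ranks 0 and pipeline_model_parallel_size - 1):

    def stage_role(pipeline_rank, pipeline_model_parallel_size):
        is_first = pipeline_rank == 0
        is_last = pipeline_rank == pipeline_model_parallel_size - 1
        return {
            'recv_activations': not is_first,   # the first stage reads the data loader instead
            'send_activations': not is_last,
            'computes_loss': is_last,           # only the last stage unpacks (loss, loss_reduced)
            'sends_input_grads': not is_first,  # backward mirror of recv_activations
        }

    for rank in range(4):
        print(rank, stage_role(rank, 4))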
megatron/utils.py  View file @ 52a5f2f2

@@ -58,7 +58,7 @@ def print_params_min_max_norm(optimizer, iteration):
     """Print min, max, and norm of all parameters."""
     index = 0
     rank = torch.distributed.get_rank()
-    string = 'iteration, rank, index, intra-layer-model-parallel, min, max, norm\n'
+    string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
     optimizer_ = optimizer
     if isinstance(optimizer, FP16_Optimizer):
         optimizer_ = optimizer.optimizer

@@ -69,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration):
             max_ = param.data.max()
             norm = param.data.norm()
             string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
-                iteration, rank, index, int(param.intra_layer_model_parallel))
+                iteration, rank, index, int(param.tensor_model_parallel))
             string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
     print(string, flush=True)
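Note (not part of this diff): the renamed per-parameter flag is just a Python attribute hung on each Parameter — get_optimizer() defaults it to False and the tensor-parallel layers mark their sharded weights with it, so the log line above prints it as 0 or 1. A standalone sketch of that flag and log format:

    import torch

    param = torch.nn.Parameter(torch.randn(4, 4))
    if not hasattr(param, 'tensor_model_parallel'):
        param.tensor_model_parallel = False    # same default as get_optimizer()

    line = '{:7d}, {:4d}, {:4d}, {:2d}, {:.6E}, {:.6E}, {:.6E}'.format(
        100, 0, 0, int(param.tensor_model_parallel),
        param.data.min().item(), param.data.max().item(), param.data.norm().item())
    print(line)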
pretrain_bert.py  View file @ 52a5f2f2

@@ -34,12 +34,12 @@ def model_provider():
     print_rank_0('building BERT model ...')

     args = get_args()
-    if args.inter_layer_model_parallel_size > 1:
+    if args.pipeline_model_parallel_size > 1:
         # Determine model based on position of stage in pipeline.
-        if mpu.is_inter_layer_first_stage():
+        if mpu.is_pipeline_first_stage():
             model = BertModelFirstStage(num_tokentypes=2)
-        elif mpu.is_inter_layer_last_stage():
+        elif mpu.is_pipeline_last_stage():
             model = BertModelLastStage(num_tokentypes=2,
                                        add_binary_head=True,

@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor):
     timers('batch generator').stop()

     # Forward pass through the model.
-    if mpu.is_inter_layer_first_stage():
+    if mpu.is_pipeline_first_stage():
         assert input_tensor is None
-        if mpu.is_inter_layer_last_stage():
+        if mpu.is_pipeline_last_stage():
             output_tensor = model(tokens, padding_mask, tokentype_ids=types,
                                   lm_labels=lm_labels)
         else:
             output_tensor = model(tokens, padding_mask, tokentype_ids=types)
-    elif mpu.is_inter_layer_last_stage():
+    elif mpu.is_pipeline_last_stage():
         assert input_tensor is not None
         output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
     else:
         assert input_tensor is not None
         output_tensor = model(input_tensor, padding_mask)

-    if mpu.is_inter_layer_last_stage():
+    if mpu.is_pipeline_last_stage():
         lm_loss_, sop_logits = output_tensor

         sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
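Note (not part of this diff): the forward_step dispatch above covers four cases, because with pipeline_model_parallel_size == 1 a stage is both first and last, which is why the full-model branch sits under the is_pipeline_first_stage() test. Enumerated as a small sketch (the labels are loose descriptions, not Megatron names):

    def forward_case(is_first, is_last):
        if is_first:
            return 'tokens in, loss out' if is_last else 'tokens in, activations out'
        elif is_last:
            return 'activations in, loss out'
        return 'activations in, activations out'

    for is_first in (True, False):
        for is_last in (True, False):
            print(is_first, is_last, '->', forward_case(is_first, is_last))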
pretrain_gpt2.py  View file @ 52a5f2f2

@@ -33,11 +33,11 @@ def model_provider():
     print_rank_0('building GPT2 model ...')

     args = get_args()
-    if args.inter_layer_model_parallel_size > 1:
+    if args.pipeline_model_parallel_size > 1:
         # Determine model based on position of stage in pipeline.
-        if mpu.is_inter_layer_first_stage():
+        if mpu.is_pipeline_first_stage():
             model = GPT2ModelFirstStage(num_tokentypes=0)
-        elif mpu.is_inter_layer_last_stage():
+        elif mpu.is_pipeline_last_stage():
             model = GPT2ModelLastStage(num_tokentypes=0,
                                        parallel_output=True)
     else:

@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor):
     timers('batch generator').stop()

     # Forward pass through the model.
-    if mpu.is_inter_layer_first_stage():
+    if mpu.is_pipeline_first_stage():
         assert input_tensor is None
-        if mpu.is_inter_layer_last_stage():
+        if mpu.is_pipeline_last_stage():
             output_tensor = model(tokens, position_ids, attention_mask,
                                   labels=labels)
         else:
             output_tensor = model(tokens, position_ids, attention_mask)
-    elif mpu.is_inter_layer_last_stage():
+    elif mpu.is_pipeline_last_stage():
         assert input_tensor is not None
         output_tensor = model(input_tensor, attention_mask, labels=labels)
     else:
         assert input_tensor is not None
         output_tensor = model(input_tensor, attention_mask)

-    if mpu.is_inter_layer_last_stage():
+    if mpu.is_pipeline_last_stage():
         losses = output_tensor.float()
         loss_mask = loss_mask.view(-1).float()
         loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
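Note (not part of this diff): the last pipeline stage averages the per-token losses over unmasked positions only. A tiny runnable example of that masked average:

    import torch

    losses = torch.tensor([[2.0, 4.0, 6.0]])      # per-token losses, shape [b, s]
    loss_mask = torch.tensor([[1.0, 1.0, 0.0]])   # 0 marks padding / ignored tokens

    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    print(loss)   # tensor(3.) -> mean over the two unmasked tokens only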