Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
52a5f2f2
Commit
52a5f2f2
authored
Oct 20, 2020
by
Deepak Narayanan
Browse files
Intra-layer MP -> Tensor MP, Inter-layer MP -> Pipeline MP
parent
7abd3e90
Changes
43
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
403 additions
and
403 deletions
+403
-403
megatron/mpu/__init__.py
megatron/mpu/__init__.py
+14
-14
megatron/mpu/cross_entropy.py
megatron/mpu/cross_entropy.py
+8
-8
megatron/mpu/data.py
megatron/mpu/data.py
+9
-9
megatron/mpu/grads.py
megatron/mpu/grads.py
+6
-6
megatron/mpu/initialize.py
megatron/mpu/initialize.py
+116
-116
megatron/mpu/layers.py
megatron/mpu/layers.py
+22
-22
megatron/mpu/mappings.py
megatron/mpu/mappings.py
+12
-12
megatron/mpu/random.py
megatron/mpu/random.py
+15
-15
megatron/mpu/tests/commons.py
megatron/mpu/tests/commons.py
+1
-1
megatron/mpu/tests/test_cross_entropy.py
megatron/mpu/tests/test_cross_entropy.py
+11
-11
megatron/mpu/tests/test_data.py
megatron/mpu/tests/test_data.py
+10
-10
megatron/mpu/tests/test_initialize.py
megatron/mpu/tests/test_initialize.py
+23
-23
megatron/mpu/tests/test_layers.py
megatron/mpu/tests/test_layers.py
+75
-75
megatron/mpu/tests/test_random.py
megatron/mpu/tests/test_random.py
+26
-26
megatron/text_generation_utils.py
megatron/text_generation_utils.py
+19
-19
megatron/tokenizer/tokenizer.py
megatron/tokenizer/tokenizer.py
+1
-1
megatron/training.py
megatron/training.py
+19
-19
megatron/utils.py
megatron/utils.py
+2
-2
pretrain_bert.py
pretrain_bert.py
+7
-7
pretrain_gpt2.py
pretrain_gpt2.py
+7
-7
No files found.
megatron/mpu/__init__.py
View file @
52a5f2f2
...
@@ -28,15 +28,15 @@ from .initialize import get_data_parallel_rank
...
@@ -28,15 +28,15 @@ from .initialize import get_data_parallel_rank
from
.initialize
import
get_data_parallel_world_size
from
.initialize
import
get_data_parallel_world_size
from
.initialize
import
get_embedding_group
from
.initialize
import
get_embedding_group
from
.initialize
import
get_model_parallel_group
from
.initialize
import
get_model_parallel_group
from
.initialize
import
get_
intra_laye
r_model_parallel_group
from
.initialize
import
get_
tenso
r_model_parallel_group
from
.initialize
import
get_
inter_layer
_model_parallel_group
from
.initialize
import
get_
pipeline
_model_parallel_group
from
.initialize
import
get_
intra_laye
r_model_parallel_rank
,
set_
intra_laye
r_model_parallel_rank
from
.initialize
import
get_
tenso
r_model_parallel_rank
,
set_
tenso
r_model_parallel_rank
from
.initialize
import
get_
inter_layer
_model_parallel_rank
,
set_
inter_layer
_model_parallel_rank
from
.initialize
import
get_
pipeline
_model_parallel_rank
,
set_
pipeline
_model_parallel_rank
from
.initialize
import
is_
inter_layer
_first_stage
,
is_
inter_layer
_last_stage
from
.initialize
import
is_
pipeline
_first_stage
,
is_
pipeline
_last_stage
from
.initialize
import
get_
intra_laye
r_model_parallel_src_rank
from
.initialize
import
get_
tenso
r_model_parallel_src_rank
from
.initialize
import
get_
inter_layer
_model_parallel_src_rank
from
.initialize
import
get_
pipeline
_model_parallel_src_rank
from
.initialize
import
get_
intra_laye
r_model_parallel_world_size
,
set_
intra_laye
r_model_parallel_world_size
from
.initialize
import
get_
tenso
r_model_parallel_world_size
,
set_
tenso
r_model_parallel_world_size
from
.initialize
import
get_
inter_layer
_model_parallel_world_size
,
set_
inter_layer
_model_parallel_world_size
from
.initialize
import
get_
pipeline
_model_parallel_world_size
,
set_
pipeline
_model_parallel_world_size
from
.initialize
import
initialize_model_parallel
from
.initialize
import
initialize_model_parallel
from
.initialize
import
model_parallel_is_initialized
from
.initialize
import
model_parallel_is_initialized
...
@@ -45,15 +45,15 @@ from .layers import ColumnParallelLinear
...
@@ -45,15 +45,15 @@ from .layers import ColumnParallelLinear
from
.layers
import
RowParallelLinear
from
.layers
import
RowParallelLinear
from
.layers
import
VocabParallelEmbedding
from
.layers
import
VocabParallelEmbedding
from
.mappings
import
copy_to_
intra_laye
r_model_parallel_region
from
.mappings
import
copy_to_
tenso
r_model_parallel_region
from
.mappings
import
gather_from_
intra_laye
r_model_parallel_region
from
.mappings
import
gather_from_
tenso
r_model_parallel_region
from
.mappings
import
reduce_from_
intra_laye
r_model_parallel_region
from
.mappings
import
reduce_from_
tenso
r_model_parallel_region
from
.mappings
import
scatter_to_
intra_laye
r_model_parallel_region
from
.mappings
import
scatter_to_
tenso
r_model_parallel_region
from
.random
import
checkpoint
from
.random
import
checkpoint
from
.random
import
get_cuda_rng_tracker
from
.random
import
get_cuda_rng_tracker
from
.random
import
init_checkpointed_activations_memory_buffer
from
.random
import
init_checkpointed_activations_memory_buffer
from
.random
import
intra_layer_
model_parallel_cuda_manual_seed
from
.random
import
model_parallel_cuda_manual_seed
from
.random
import
reset_checkpointed_activations_memory_buffer
from
.random
import
reset_checkpointed_activations_memory_buffer
from
.utils
import
divide
from
.utils
import
divide
...
...
megatron/mpu/cross_entropy.py
View file @
52a5f2f2
...
@@ -16,9 +16,9 @@
...
@@ -16,9 +16,9 @@
import
torch
import
torch
from
.initialize
import
get_
intra_laye
r_model_parallel_group
from
.initialize
import
get_
tenso
r_model_parallel_group
from
.initialize
import
get_
intra_laye
r_model_parallel_rank
from
.initialize
import
get_
tenso
r_model_parallel_rank
from
.initialize
import
get_
intra_laye
r_model_parallel_world_size
from
.initialize
import
get_
tenso
r_model_parallel_world_size
from
.utils
import
VocabUtility
from
.utils
import
VocabUtility
...
@@ -31,15 +31,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
...
@@ -31,15 +31,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
logits_max
=
torch
.
max
(
vocab_parallel_logits
,
dim
=-
1
)[
0
]
logits_max
=
torch
.
max
(
vocab_parallel_logits
,
dim
=-
1
)[
0
]
torch
.
distributed
.
all_reduce
(
logits_max
,
torch
.
distributed
.
all_reduce
(
logits_max
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
get_
intra_laye
r_model_parallel_group
())
group
=
get_
tenso
r_model_parallel_group
())
# Subtract the maximum value.
# Subtract the maximum value.
vocab_parallel_logits
.
sub_
(
logits_max
.
unsqueeze
(
dim
=-
1
))
vocab_parallel_logits
.
sub_
(
logits_max
.
unsqueeze
(
dim
=-
1
))
# Get the partition's vocab indecies
# Get the partition's vocab indecies
get_vocab_range
=
VocabUtility
.
vocab_range_from_per_partition_vocab_size
get_vocab_range
=
VocabUtility
.
vocab_range_from_per_partition_vocab_size
partition_vocab_size
=
vocab_parallel_logits
.
size
()[
-
1
]
partition_vocab_size
=
vocab_parallel_logits
.
size
()[
-
1
]
rank
=
get_
intra_laye
r_model_parallel_rank
()
rank
=
get_
tenso
r_model_parallel_rank
()
world_size
=
get_
intra_laye
r_model_parallel_world_size
()
world_size
=
get_
tenso
r_model_parallel_world_size
()
vocab_start_index
,
vocab_end_index
=
get_vocab_range
(
vocab_start_index
,
vocab_end_index
=
get_vocab_range
(
partition_vocab_size
,
rank
,
world_size
)
partition_vocab_size
,
rank
,
world_size
)
...
@@ -62,7 +62,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
...
@@ -62,7 +62,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# All reduce is needed to get the chunks from other GPUs.
# All reduce is needed to get the chunks from other GPUs.
torch
.
distributed
.
all_reduce
(
predicted_logits
,
torch
.
distributed
.
all_reduce
(
predicted_logits
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
get_
intra_laye
r_model_parallel_group
())
group
=
get_
tenso
r_model_parallel_group
())
# Sum of exponential of logits along vocab dimension across all GPUs.
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits
=
vocab_parallel_logits
exp_logits
=
vocab_parallel_logits
...
@@ -70,7 +70,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
...
@@ -70,7 +70,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
sum_exp_logits
=
exp_logits
.
sum
(
dim
=-
1
)
sum_exp_logits
=
exp_logits
.
sum
(
dim
=-
1
)
torch
.
distributed
.
all_reduce
(
sum_exp_logits
,
torch
.
distributed
.
all_reduce
(
sum_exp_logits
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
get_
intra_laye
r_model_parallel_group
())
group
=
get_
tenso
r_model_parallel_group
())
# Loss = log(sum(exp(logits))) - predicted-logit.
# Loss = log(sum(exp(logits))) - predicted-logit.
loss
=
torch
.
log
(
sum_exp_logits
)
-
predicted_logits
loss
=
torch
.
log
(
sum_exp_logits
)
-
predicted_logits
...
...
megatron/mpu/data.py
View file @
52a5f2f2
...
@@ -15,9 +15,9 @@
...
@@ -15,9 +15,9 @@
import
torch
import
torch
from
.initialize
import
get_
intra_laye
r_model_parallel_group
from
.initialize
import
get_
tenso
r_model_parallel_group
from
.initialize
import
get_
intra_laye
r_model_parallel_rank
from
.initialize
import
get_
tenso
r_model_parallel_rank
from
.initialize
import
get_
intra_laye
r_model_parallel_src_rank
from
.initialize
import
get_
tenso
r_model_parallel_src_rank
_MAX_DATA_DIM
=
4
_MAX_DATA_DIM
=
4
...
@@ -36,7 +36,7 @@ def _build_key_size_numel_dictionaries(keys, data):
...
@@ -36,7 +36,7 @@ def _build_key_size_numel_dictionaries(keys, data):
sizes
=
[
0
for
_
in
range
(
max_dim
)
for
_
in
keys
]
sizes
=
[
0
for
_
in
range
(
max_dim
)
for
_
in
keys
]
# Pack the sizes on rank zero.
# Pack the sizes on rank zero.
if
get_
intra_laye
r_model_parallel_rank
()
==
0
:
if
get_
tenso
r_model_parallel_rank
()
==
0
:
offset
=
0
offset
=
0
for
key
in
keys
:
for
key
in
keys
:
assert
data
[
key
].
dim
()
<
max_dim
,
'you should increase MAX_DATA_DIM'
assert
data
[
key
].
dim
()
<
max_dim
,
'you should increase MAX_DATA_DIM'
...
@@ -47,8 +47,8 @@ def _build_key_size_numel_dictionaries(keys, data):
...
@@ -47,8 +47,8 @@ def _build_key_size_numel_dictionaries(keys, data):
# Move to GPU and broadcast.
# Move to GPU and broadcast.
sizes_cuda
=
torch
.
cuda
.
LongTensor
(
sizes
)
sizes_cuda
=
torch
.
cuda
.
LongTensor
(
sizes
)
torch
.
distributed
.
broadcast
(
sizes_cuda
,
get_
intra_laye
r_model_parallel_src_rank
(),
torch
.
distributed
.
broadcast
(
sizes_cuda
,
get_
tenso
r_model_parallel_src_rank
(),
group
=
get_
intra_laye
r_model_parallel_group
())
group
=
get_
tenso
r_model_parallel_group
())
# Move back to cpu and unpack.
# Move back to cpu and unpack.
sizes_cpu
=
sizes_cuda
.
cpu
()
sizes_cpu
=
sizes_cuda
.
cpu
()
...
@@ -89,7 +89,7 @@ def broadcast_data(keys, data, datatype):
...
@@ -89,7 +89,7 @@ def broadcast_data(keys, data, datatype):
data
)
data
)
# Pack on rank zero.
# Pack on rank zero.
if
get_
intra_laye
r_model_parallel_rank
()
==
0
:
if
get_
tenso
r_model_parallel_rank
()
==
0
:
# Check that all keys have the same data type.
# Check that all keys have the same data type.
_check_data_types
(
keys
,
data
,
datatype
)
_check_data_types
(
keys
,
data
,
datatype
)
# Flatten the data associated with the keys
# Flatten the data associated with the keys
...
@@ -101,8 +101,8 @@ def broadcast_data(keys, data, datatype):
...
@@ -101,8 +101,8 @@ def broadcast_data(keys, data, datatype):
dtype
=
datatype
)
dtype
=
datatype
)
# Broadcast
# Broadcast
torch
.
distributed
.
broadcast
(
flatten_data
,
get_
intra_laye
r_model_parallel_src_rank
(),
torch
.
distributed
.
broadcast
(
flatten_data
,
get_
tenso
r_model_parallel_src_rank
(),
group
=
get_
intra_laye
r_model_parallel_group
())
group
=
get_
tenso
r_model_parallel_group
())
# Unpack
# Unpack
output
=
{}
output
=
{}
...
...
megatron/mpu/grads.py
View file @
52a5f2f2
...
@@ -28,9 +28,9 @@ try:
...
@@ -28,9 +28,9 @@ try:
except
Exception
as
e
:
except
Exception
as
e
:
print
(
'WARNING: APEX is not installed, multi_tensor_applier will not be available.'
)
print
(
'WARNING: APEX is not installed, multi_tensor_applier will not be available.'
)
from
.initialize
import
is_
inter_layer
_first_stage
from
.initialize
import
is_
pipeline
_first_stage
from
.initialize
import
get_model_parallel_group
from
.initialize
import
get_model_parallel_group
from
.initialize
import
get_
intra_laye
r_model_parallel_rank
from
.initialize
import
get_
tenso
r_model_parallel_rank
def
l2_grad_clipper
(
parameters
,
max_norm
):
def
l2_grad_clipper
(
parameters
,
max_norm
):
...
@@ -44,9 +44,9 @@ def l2_grad_clipper(parameters, max_norm):
...
@@ -44,9 +44,9 @@ def l2_grad_clipper(parameters, max_norm):
parameters_with_grads
=
list
(
filter
(
parameters_with_grads
=
list
(
filter
(
lambda
p
:
p
.
grad
is
not
None
,
parameters
))
lambda
p
:
p
.
grad
is
not
None
,
parameters
))
# Filter parameters for norm calculations.
# Filter parameters for norm calculations.
mp_rank_is_zero
=
(
get_
intra_laye
r_model_parallel_rank
()
==
0
)
mp_rank_is_zero
=
(
get_
tenso
r_model_parallel_rank
()
==
0
)
parameters_for_norm
=
list
(
filter
(
parameters_for_norm
=
list
(
filter
(
lambda
p
:
p
.
intra_laye
r_model_parallel
or
mp_rank_is_zero
,
parameters_with_grads
))
lambda
p
:
p
.
tenso
r_model_parallel
or
mp_rank_is_zero
,
parameters_with_grads
))
# Calculate L2 norm.
# Calculate L2 norm.
norm
,
_
=
multi_tensor_applier
(
norm
,
_
=
multi_tensor_applier
(
amp_C
.
multi_tensor_l2norm
,
amp_C
.
multi_tensor_l2norm
,
...
@@ -101,7 +101,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
...
@@ -101,7 +101,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
# Count embedding layer only once (in first stage).
# Count embedding layer only once (in first stage).
# Don't count the weights a second time in the last stage.
# Don't count the weights a second time in the last stage.
if
"embedding"
not
in
n
or
\
if
"embedding"
not
in
n
or
\
is_
inter_layer
_first_stage
():
is_
pipeline
_first_stage
():
filtered_parameters
.
append
(
p
)
filtered_parameters
.
append
(
p
)
parameters
=
filtered_parameters
parameters
=
filtered_parameters
else
:
else
:
...
@@ -123,7 +123,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
...
@@ -123,7 +123,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
else
:
else
:
total_norm
=
0
total_norm
=
0
for
p
in
parameters
:
for
p
in
parameters
:
if
p
.
intra_laye
r_model_parallel
or
(
get_
intra_laye
r_model_parallel_rank
()
==
0
):
if
p
.
tenso
r_model_parallel
or
(
get_
tenso
r_model_parallel_rank
()
==
0
):
param_norm
=
p
.
grad
.
data
.
norm
(
norm_type
)
param_norm
=
p
.
grad
.
data
.
norm
(
norm_type
)
total_norm
+=
param_norm
.
item
()
**
norm_type
total_norm
+=
param_norm
.
item
()
**
norm_type
# Sum across all model-parallel GPUs.
# Sum across all model-parallel GPUs.
...
...
megatron/mpu/initialize.py
View file @
52a5f2f2
...
@@ -22,10 +22,10 @@ from .utils import ensure_divisibility
...
@@ -22,10 +22,10 @@ from .utils import ensure_divisibility
# Intra-layer model parallel group that the current rank belongs to.
# Intra-layer model parallel group that the current rank belongs to.
_
INTRA_LAYE
R_MODEL_PARALLEL_GROUP
=
None
_
TENSO
R_MODEL_PARALLEL_GROUP
=
None
# Inter-layer model parallel group that the current rank belongs to.
# Inter-layer model parallel group that the current rank belongs to.
_
INTER_LAYER
_MODEL_PARALLEL_GROUP
=
None
_
PIPELINE
_MODEL_PARALLEL_GROUP
=
None
# Model parallel group (both intra- and
inter-layer
) that the current rank belongs to.
# Model parallel group (both intra- and
pipeline
) that the current rank belongs to.
_MODEL_PARALLEL_GROUP
=
None
_MODEL_PARALLEL_GROUP
=
None
# Embedding group.
# Embedding group.
_EMBEDDING_GROUP
=
None
_EMBEDDING_GROUP
=
None
...
@@ -33,10 +33,10 @@ _EMBEDDING_GROUP = None
...
@@ -33,10 +33,10 @@ _EMBEDDING_GROUP = None
_DATA_PARALLEL_GROUP
=
None
_DATA_PARALLEL_GROUP
=
None
# These values enable us to change the mpu sizes on the fly.
# These values enable us to change the mpu sizes on the fly.
_MPU_
INTRA_LAYER
_WORLD_SIZE
=
None
_MPU_
TENSOR_MODEL_PARALLEL
_WORLD_SIZE
=
None
_MPU_
INTER_LAYER
_WORLD_SIZE
=
None
_MPU_
PIPELINE_MODEL_PARALLEL
_WORLD_SIZE
=
None
_MPU_
INTRA_LAYER
_RANK
=
None
_MPU_
TENSOR_MODEL_PARALLEL
_RANK
=
None
_MPU_
INTER_LAYER
_RANK
=
None
_MPU_
PIPELINE_MODEL_PARALLEL
_RANK
=
None
def
is_unitialized
():
def
is_unitialized
():
...
@@ -44,25 +44,25 @@ def is_unitialized():
...
@@ -44,25 +44,25 @@ def is_unitialized():
return
_DATA_PARALLEL_GROUP
is
None
return
_DATA_PARALLEL_GROUP
is
None
def
initialize_model_parallel
(
intra_laye
r_model_parallel_size_
=
1
,
def
initialize_model_parallel
(
tenso
r_model_parallel_size_
=
1
,
inter_layer
_model_parallel_size_
=
1
):
pipeline
_model_parallel_size_
=
1
):
"""
"""
Initialize model data parallel groups.
Initialize model data parallel groups.
Arguments:
Arguments:
intra_laye
r_model_parallel_size: number of GPUs used to parallelize model
intra-laye
r.
tenso
r_model_parallel_size: number of GPUs used to parallelize model
tenso
r.
inter_layer
_model_parallel_size: number of GPUs used to parallelize model
inter-layer
.
pipeline
_model_parallel_size: number of GPUs used to parallelize model
pipeline
.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model
intra-laye
r, and 4 GPUs to parallelize
use 2 GPUs to parallelize the model
tenso
r, and 4 GPUs to parallelize
the model
inter-layer
. The present function will
the model
pipeline
. The present function will
create 8
intra-laye
r model-parallel groups, 4
inter-layer
model-parallel groups
create 8
tenso
r model-parallel groups, 4
pipeline
model-parallel groups
and 8 data-parallel groups as:
and 8 data-parallel groups as:
8 data_parallel groups:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8
intra-laye
r model-parallel groups:
8
tenso
r model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4
inter-layer
model-parallel groups:
4
pipeline
model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
Note that for efficiency, the caller should make sure adjacent ranks
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
are on the same DGX box. For example if we are using 2 DGX-1 boxes
...
@@ -70,22 +70,22 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
...
@@ -70,22 +70,22 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
ranks 8 to 15 belong to the second box.
ranks 8 to 15 belong to the second box.
"""
"""
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> initializing
intra-laye
r model parallel with size {}'
.
format
(
print
(
'> initializing
tenso
r model parallel with size {}'
.
format
(
intra_laye
r_model_parallel_size_
))
tenso
r_model_parallel_size_
))
print
(
'> initializing
inter-layer
model parallel with size {}'
.
format
(
print
(
'> initializing
pipeline
model parallel with size {}'
.
format
(
inter_layer
_model_parallel_size_
))
pipeline
_model_parallel_size_
))
# Get world size and rank. Ensure some consistencies.
# Get world size and rank. Ensure some consistencies.
assert
torch
.
distributed
.
is_initialized
()
assert
torch
.
distributed
.
is_initialized
()
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
intra_laye
r_model_parallel_size
=
min
(
intra_laye
r_model_parallel_size_
,
world_size
)
tenso
r_model_parallel_size
=
min
(
tenso
r_model_parallel_size_
,
world_size
)
inter_layer
_model_parallel_size
=
min
(
inter_layer
_model_parallel_size_
,
world_size
)
pipeline
_model_parallel_size
=
min
(
pipeline
_model_parallel_size_
,
world_size
)
ensure_divisibility
(
world_size
,
ensure_divisibility
(
world_size
,
intra_laye
r_model_parallel_size
*
inter_layer
_model_parallel_size
)
tenso
r_model_parallel_size
*
pipeline
_model_parallel_size
)
data_parallel_size
=
world_size
//
(
intra_laye
r_model_parallel_size
*
data_parallel_size
=
world_size
//
(
tenso
r_model_parallel_size
*
inter_layer
_model_parallel_size
)
pipeline
_model_parallel_size
)
num_
intra_laye
r_model_parallel_groups
=
world_size
//
intra_laye
r_model_parallel_size
num_
tenso
r_model_parallel_groups
=
world_size
//
tenso
r_model_parallel_size
num_
inter_layer
_model_parallel_groups
=
world_size
//
inter_layer
_model_parallel_size
num_
pipeline
_model_parallel_groups
=
world_size
//
pipeline
_model_parallel_size
num_data_parallel_groups
=
world_size
//
data_parallel_size
num_data_parallel_groups
=
world_size
//
data_parallel_size
rank
=
torch
.
distributed
.
get_rank
()
rank
=
torch
.
distributed
.
get_rank
()
...
@@ -95,12 +95,12 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
...
@@ -95,12 +95,12 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
assert
_DATA_PARALLEL_GROUP
is
None
,
\
assert
_DATA_PARALLEL_GROUP
is
None
,
\
'data parallel group is already initialized'
'data parallel group is already initialized'
all_data_parallel_group_ranks
=
[]
all_data_parallel_group_ranks
=
[]
for
i
in
range
(
inter_layer
_model_parallel_size
):
for
i
in
range
(
pipeline
_model_parallel_size
):
start_rank
=
i
*
num_
inter_layer
_model_parallel_groups
start_rank
=
i
*
num_
pipeline
_model_parallel_groups
end_rank
=
(
i
+
1
)
*
num_
inter_layer
_model_parallel_groups
end_rank
=
(
i
+
1
)
*
num_
pipeline
_model_parallel_groups
for
j
in
range
(
intra_laye
r_model_parallel_size
):
for
j
in
range
(
tenso
r_model_parallel_size
):
ranks
=
range
(
start_rank
+
j
,
end_rank
,
ranks
=
range
(
start_rank
+
j
,
end_rank
,
intra_laye
r_model_parallel_size
)
tenso
r_model_parallel_size
)
all_data_parallel_group_ranks
.
append
(
list
(
ranks
))
all_data_parallel_group_ranks
.
append
(
list
(
ranks
))
group
=
torch
.
distributed
.
new_group
(
ranks
)
group
=
torch
.
distributed
.
new_group
(
ranks
)
if
rank
in
ranks
:
if
rank
in
ranks
:
...
@@ -117,31 +117,31 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
...
@@ -117,31 +117,31 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
if
rank
in
ranks
:
if
rank
in
ranks
:
_MODEL_PARALLEL_GROUP
=
group
_MODEL_PARALLEL_GROUP
=
group
# Build the
intra-laye
r model-parallel groups.
# Build the
tenso
r model-parallel groups.
global
_
INTRA_LAYE
R_MODEL_PARALLEL_GROUP
global
_
TENSO
R_MODEL_PARALLEL_GROUP
assert
_
INTRA_LAYE
R_MODEL_PARALLEL_GROUP
is
None
,
\
assert
_
TENSO
R_MODEL_PARALLEL_GROUP
is
None
,
\
'
intra-laye
r model parallel group is already initialized'
'
tenso
r model parallel group is already initialized'
for
i
in
range
(
num_
intra_laye
r_model_parallel_groups
):
for
i
in
range
(
num_
tenso
r_model_parallel_groups
):
ranks
=
range
(
i
*
intra_laye
r_model_parallel_size
,
ranks
=
range
(
i
*
tenso
r_model_parallel_size
,
(
i
+
1
)
*
intra_laye
r_model_parallel_size
)
(
i
+
1
)
*
tenso
r_model_parallel_size
)
group
=
torch
.
distributed
.
new_group
(
ranks
)
group
=
torch
.
distributed
.
new_group
(
ranks
)
if
rank
in
ranks
:
if
rank
in
ranks
:
_
INTRA_LAYE
R_MODEL_PARALLEL_GROUP
=
group
_
TENSO
R_MODEL_PARALLEL_GROUP
=
group
# Build the
inter-layer
model-parallel groups and embedding groups
# Build the
pipeline
model-parallel groups and embedding groups
# (first and last rank in each
inter-layer
model-parallel group).
# (first and last rank in each
pipeline
model-parallel group).
global
_
INTER_LAYER
_MODEL_PARALLEL_GROUP
global
_
PIPELINE
_MODEL_PARALLEL_GROUP
assert
_
INTER_LAYER
_MODEL_PARALLEL_GROUP
is
None
,
\
assert
_
PIPELINE
_MODEL_PARALLEL_GROUP
is
None
,
\
'
inter-layer
model parallel group is already initialized'
'
pipeline
model parallel group is already initialized'
global
_EMBEDDING_GROUP
global
_EMBEDDING_GROUP
assert
_EMBEDDING_GROUP
is
None
,
\
assert
_EMBEDDING_GROUP
is
None
,
\
'embedding group is already initialized'
'embedding group is already initialized'
for
i
in
range
(
num_
inter_layer
_model_parallel_groups
):
for
i
in
range
(
num_
pipeline
_model_parallel_groups
):
ranks
=
range
(
i
,
world_size
,
ranks
=
range
(
i
,
world_size
,
num_
inter_layer
_model_parallel_groups
)
num_
pipeline
_model_parallel_groups
)
group
=
torch
.
distributed
.
new_group
(
ranks
)
group
=
torch
.
distributed
.
new_group
(
ranks
)
if
rank
in
ranks
:
if
rank
in
ranks
:
_
INTER_LAYER
_MODEL_PARALLEL_GROUP
=
group
_
PIPELINE
_MODEL_PARALLEL_GROUP
=
group
# Setup embedding group (to exchange gradients between
# Setup embedding group (to exchange gradients between
# first and last stages).
# first and last stages).
if
len
(
ranks
)
>
1
:
if
len
(
ranks
)
>
1
:
...
@@ -155,8 +155,8 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
...
@@ -155,8 +155,8 @@ def initialize_model_parallel(intra_layer_model_parallel_size_=1,
def
model_parallel_is_initialized
():
def
model_parallel_is_initialized
():
"""Check if model and data parallel groups are initialized."""
"""Check if model and data parallel groups are initialized."""
if
_
INTRA_LAYE
R_MODEL_PARALLEL_GROUP
is
None
or
\
if
_
TENSO
R_MODEL_PARALLEL_GROUP
is
None
or
\
_
INTER_LAYER
_MODEL_PARALLEL_GROUP
is
None
or
\
_
PIPELINE
_MODEL_PARALLEL_GROUP
is
None
or
\
_DATA_PARALLEL_GROUP
is
None
:
_DATA_PARALLEL_GROUP
is
None
:
return
False
return
False
return
True
return
True
...
@@ -169,18 +169,18 @@ def get_model_parallel_group():
...
@@ -169,18 +169,18 @@ def get_model_parallel_group():
return
_MODEL_PARALLEL_GROUP
return
_MODEL_PARALLEL_GROUP
def
get_
intra_laye
r_model_parallel_group
():
def
get_
tenso
r_model_parallel_group
():
"""Get the
intra-laye
r model parallel group the caller rank belongs to."""
"""Get the
tenso
r model parallel group the caller rank belongs to."""
assert
_
INTRA_LAYE
R_MODEL_PARALLEL_GROUP
is
not
None
,
\
assert
_
TENSO
R_MODEL_PARALLEL_GROUP
is
not
None
,
\
'intra_layer_model parallel group is not initialized'
'intra_layer_model parallel group is not initialized'
return
_
INTRA_LAYE
R_MODEL_PARALLEL_GROUP
return
_
TENSO
R_MODEL_PARALLEL_GROUP
def
get_
inter_layer
_model_parallel_group
():
def
get_
pipeline
_model_parallel_group
():
"""Get the
inter-layer
model parallel group the caller rank belongs to."""
"""Get the
pipeline
model parallel group the caller rank belongs to."""
assert
_
INTER_LAYER
_MODEL_PARALLEL_GROUP
is
not
None
,
\
assert
_
PIPELINE
_MODEL_PARALLEL_GROUP
is
not
None
,
\
'
inter_layer
_model parallel group is not initialized'
'
pipeline
_model parallel group is not initialized'
return
_
INTER_LAYER
_MODEL_PARALLEL_GROUP
return
_
PIPELINE
_MODEL_PARALLEL_GROUP
def
get_data_parallel_group
():
def
get_data_parallel_group
():
...
@@ -197,87 +197,87 @@ def get_embedding_group():
...
@@ -197,87 +197,87 @@ def get_embedding_group():
return
_EMBEDDING_GROUP
return
_EMBEDDING_GROUP
def
set_
intra_laye
r_model_parallel_world_size
(
world_size
):
def
set_
tenso
r_model_parallel_world_size
(
world_size
):
"""Set the
intra-laye
r model parallel size"""
"""Set the
tenso
r model parallel size"""
global
_MPU_
INTRA_LAYER
_WORLD_SIZE
global
_MPU_
TENSOR_MODEL_PARALLEL
_WORLD_SIZE
_MPU_
INTRA_LAYER
_WORLD_SIZE
=
world_size
_MPU_
TENSOR_MODEL_PARALLEL
_WORLD_SIZE
=
world_size
def
set_
inter_layer
_model_parallel_world_size
(
world_size
):
def
set_
pipeline
_model_parallel_world_size
(
world_size
):
"""Set the
inter-layer
model parallel size"""
"""Set the
pipeline
model parallel size"""
global
_MPU_
INTER_LAYER
_WORLD_SIZE
global
_MPU_
PIPELINE_MODEL_PARALLEL
_WORLD_SIZE
_MPU_
INTER_LAYER
_WORLD_SIZE
=
world_size
_MPU_
PIPELINE_MODEL_PARALLEL
_WORLD_SIZE
=
world_size
def
get_
intra_laye
r_model_parallel_world_size
():
def
get_
tenso
r_model_parallel_world_size
():
"""Return world size for the
intra-laye
r model parallel group."""
"""Return world size for the
tenso
r model parallel group."""
global
_MPU_
INTRA_LAYER
_WORLD_SIZE
global
_MPU_
TENSOR_MODEL_PARALLEL
_WORLD_SIZE
if
_MPU_
INTRA_LAYER
_WORLD_SIZE
is
not
None
:
if
_MPU_
TENSOR_MODEL_PARALLEL
_WORLD_SIZE
is
not
None
:
return
_MPU_
INTRA_LAYER
_WORLD_SIZE
return
_MPU_
TENSOR_MODEL_PARALLEL
_WORLD_SIZE
return
torch
.
distributed
.
get_world_size
(
group
=
get_
intra_laye
r_model_parallel_group
())
return
torch
.
distributed
.
get_world_size
(
group
=
get_
tenso
r_model_parallel_group
())
def
get_
inter_layer
_model_parallel_world_size
():
def
get_
pipeline
_model_parallel_world_size
():
"""Return world size for the
inter-layer
model parallel group."""
"""Return world size for the
pipeline
model parallel group."""
global
_MPU_
INTER_LAYER
_WORLD_SIZE
global
_MPU_
PIPELINE_MODEL_PARALLEL
_WORLD_SIZE
if
_MPU_
INTER_LAYER
_WORLD_SIZE
is
not
None
:
if
_MPU_
PIPELINE_MODEL_PARALLEL
_WORLD_SIZE
is
not
None
:
return
_MPU_
INTER_LAYER
_WORLD_SIZE
return
_MPU_
PIPELINE_MODEL_PARALLEL
_WORLD_SIZE
return
torch
.
distributed
.
get_world_size
(
group
=
get_
inter_layer
_model_parallel_group
())
return
torch
.
distributed
.
get_world_size
(
group
=
get_
pipeline
_model_parallel_group
())
def
set_
intra_laye
r_model_parallel_rank
(
rank
):
def
set_
tenso
r_model_parallel_rank
(
rank
):
"""Set
intra-laye
r model parallel rank."""
"""Set
tenso
r model parallel rank."""
global
_MPU_
INTRA_LAYER
_RANK
global
_MPU_
TENSOR_MODEL_PARALLEL
_RANK
_MPU_
INTRA_LAYER
_RANK
=
rank
_MPU_
TENSOR_MODEL_PARALLEL
_RANK
=
rank
def
set_
inter_layer
_model_parallel_rank
(
rank
):
def
set_
pipeline
_model_parallel_rank
(
rank
):
"""Set
inter-layer
model parallel rank."""
"""Set
pipeline
model parallel rank."""
global
_MPU_
INTER_LAYER
_RANK
global
_MPU_
PIPELINE_MODEL_PARALLEL
_RANK
_MPU_
INTER_LAYER
_RANK
=
rank
_MPU_
PIPELINE_MODEL_PARALLEL
_RANK
=
rank
def
get_
intra_laye
r_model_parallel_rank
():
def
get_
tenso
r_model_parallel_rank
():
"""Return my rank for the
intra-laye
r model parallel group."""
"""Return my rank for the
tenso
r model parallel group."""
global
_MPU_
INTRA_LAYER
_RANK
global
_MPU_
TENSOR_MODEL_PARALLEL
_RANK
if
_MPU_
INTRA_LAYER
_RANK
is
not
None
:
if
_MPU_
TENSOR_MODEL_PARALLEL
_RANK
is
not
None
:
return
_MPU_
INTRA_LAYER
_RANK
return
_MPU_
TENSOR_MODEL_PARALLEL
_RANK
return
torch
.
distributed
.
get_rank
(
group
=
get_
intra_laye
r_model_parallel_group
())
return
torch
.
distributed
.
get_rank
(
group
=
get_
tenso
r_model_parallel_group
())
def
get_
inter_layer
_model_parallel_rank
():
def
get_
pipeline
_model_parallel_rank
():
"""Return my rank for the
inter-layer
model parallel group."""
"""Return my rank for the
pipeline
model parallel group."""
global
_MPU_
INTER_LAYER
_RANK
global
_MPU_
PIPELINE_MODEL_PARALLEL
_RANK
if
_MPU_
INTER_LAYER
_RANK
is
not
None
:
if
_MPU_
PIPELINE_MODEL_PARALLEL
_RANK
is
not
None
:
return
_MPU_
INTER_LAYER
_RANK
return
_MPU_
PIPELINE_MODEL_PARALLEL
_RANK
return
torch
.
distributed
.
get_rank
(
group
=
get_
inter_layer
_model_parallel_group
())
return
torch
.
distributed
.
get_rank
(
group
=
get_
pipeline
_model_parallel_group
())
def
is_
inter_layer
_first_stage
():
def
is_
pipeline
_first_stage
():
"""Return True if in the first
inter-layer
model-parallel stage, False otherwise."""
"""Return True if in the first
pipeline
model-parallel stage, False otherwise."""
return
get_
inter_layer
_model_parallel_rank
()
==
0
return
get_
pipeline
_model_parallel_rank
()
==
0
def
is_
inter_layer
_last_stage
():
def
is_
pipeline
_last_stage
():
"""Return True if in the last
inter-layer
model-parallel stage, False otherwise."""
"""Return True if in the last
pipeline
model-parallel stage, False otherwise."""
return
get_
inter_layer
_model_parallel_rank
()
==
(
return
get_
pipeline
_model_parallel_rank
()
==
(
get_
inter_layer
_model_parallel_world_size
()
-
1
)
get_
pipeline
_model_parallel_world_size
()
-
1
)
def
get_
intra_laye
r_model_parallel_src_rank
():
def
get_
tenso
r_model_parallel_src_rank
():
"""Calculate the global rank corresponding to a local rank
"""Calculate the global rank corresponding to a local rank
in the
intra-laye
r model parallel group."""
in the
tenso
r model parallel group."""
global_rank
=
torch
.
distributed
.
get_rank
()
global_rank
=
torch
.
distributed
.
get_rank
()
local_world_size
=
get_
intra_laye
r_model_parallel_world_size
()
local_world_size
=
get_
tenso
r_model_parallel_world_size
()
return
(
global_rank
//
local_world_size
)
*
local_world_size
return
(
global_rank
//
local_world_size
)
*
local_world_size
def
get_
inter_layer
_model_parallel_src_rank
():
def
get_
pipeline
_model_parallel_src_rank
():
"""Calculate the global rank corresponding to a local rank
"""Calculate the global rank corresponding to a local rank
in the
inter-layer
model parallel group."""
in the
pipeline
model parallel group."""
global_rank
=
torch
.
distributed
.
get_rank
()
global_rank
=
torch
.
distributed
.
get_rank
()
global_world_size
=
torch
.
distributed
.
get_world_size
()
global_world_size
=
torch
.
distributed
.
get_world_size
()
local_world_size
=
get_
inter_layer
_model_parallel_world_size
()
local_world_size
=
get_
pipeline
_model_parallel_world_size
()
return
global_rank
%
(
global_world_size
//
local_world_size
)
return
global_rank
%
(
global_world_size
//
local_world_size
)
...
@@ -293,9 +293,9 @@ def get_data_parallel_rank():
...
@@ -293,9 +293,9 @@ def get_data_parallel_rank():
def
destroy_model_parallel
():
def
destroy_model_parallel
():
"""Set the groups to none."""
"""Set the groups to none."""
global
_
INTRA_LAYE
R_MODEL_PARALLEL_GROUP
global
_
TENSO
R_MODEL_PARALLEL_GROUP
_
INTRA_LAYE
R_MODEL_PARALLEL_GROUP
=
None
_
TENSO
R_MODEL_PARALLEL_GROUP
=
None
global
_
INTER_LAYER
_MODEL_PARALLEL_GROUP
global
_
PIPELINE
_MODEL_PARALLEL_GROUP
_
INTER_LAYER
_MODEL_PARALLEL_GROUP
=
None
_
PIPELINE
_MODEL_PARALLEL_GROUP
=
None
global
_DATA_PARALLEL_GROUP
global
_DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP
=
None
_DATA_PARALLEL_GROUP
=
None
megatron/mpu/layers.py
View file @
52a5f2f2
...
@@ -35,12 +35,12 @@ except Exception as e:
...
@@ -35,12 +35,12 @@ except Exception as e:
'instead of apex.normalization.FusedLayerNorm!'
)
'instead of apex.normalization.FusedLayerNorm!'
)
from
torch.nn
import
LayerNorm
from
torch.nn
import
LayerNorm
from
.initialize
import
get_
intra_laye
r_model_parallel_rank
from
.initialize
import
get_
tenso
r_model_parallel_rank
from
.initialize
import
get_
intra_laye
r_model_parallel_world_size
from
.initialize
import
get_
tenso
r_model_parallel_world_size
from
.mappings
import
copy_to_
intra_laye
r_model_parallel_region
from
.mappings
import
copy_to_
tenso
r_model_parallel_region
from
.mappings
import
gather_from_
intra_laye
r_model_parallel_region
from
.mappings
import
gather_from_
tenso
r_model_parallel_region
from
.mappings
import
reduce_from_
intra_laye
r_model_parallel_region
from
.mappings
import
reduce_from_
tenso
r_model_parallel_region
from
.mappings
import
scatter_to_
intra_laye
r_model_parallel_region
from
.mappings
import
scatter_to_
tenso
r_model_parallel_region
from
.random
import
get_cuda_rng_tracker
from
.random
import
get_cuda_rng_tracker
from
.utils
import
divide
from
.utils
import
divide
from
.utils
import
split_tensor_along_last_dim
from
.utils
import
split_tensor_along_last_dim
...
@@ -51,7 +51,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
...
@@ -51,7 +51,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
partition_dim
,
stride
=
1
):
partition_dim
,
stride
=
1
):
"""Initialize affine weight for model parallel on GPU."""
"""Initialize affine weight for model parallel on GPU."""
weight
.
intra_laye
r_model_parallel
=
True
weight
.
tenso
r_model_parallel
=
True
weight
.
partition_dim
=
partition_dim
weight
.
partition_dim
=
partition_dim
weight
.
partition_stride
=
stride
weight
.
partition_stride
=
stride
...
@@ -68,7 +68,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
...
@@ -68,7 +68,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
Build the master weight on all processes and scatter
Build the master weight on all processes and scatter
the relevant chunk."""
the relevant chunk."""
weight
.
intra_laye
r_model_parallel
=
True
weight
.
tenso
r_model_parallel
=
True
weight
.
partition_dim
=
partition_dim
weight
.
partition_dim
=
partition_dim
weight
.
partition_stride
=
stride
weight
.
partition_stride
=
stride
...
@@ -85,7 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
...
@@ -85,7 +85,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
weight_list
=
torch
.
split
(
master_weight
,
per_partition_per_stride_size
,
weight_list
=
torch
.
split
(
master_weight
,
per_partition_per_stride_size
,
dim
=
partition_dim
)
dim
=
partition_dim
)
rank
=
get_model_parallel_rank
()
rank
=
get_model_parallel_rank
()
world_size
=
get_
intra_laye
r_model_parallel_world_size
()
world_size
=
get_
tenso
r_model_parallel_world_size
()
my_weight_list
=
weight_list
[
rank
::
world_size
]
my_weight_list
=
weight_list
[
rank
::
world_size
]
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -119,12 +119,12 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -119,12 +119,12 @@ class VocabParallelEmbedding(torch.nn.Module):
self
.
scale_grad_by_freq
=
False
self
.
scale_grad_by_freq
=
False
self
.
sparse
=
False
self
.
sparse
=
False
self
.
_weight
=
None
self
.
_weight
=
None
self
.
intra_laye
r_model_parallel_size
=
get_
intra_laye
r_model_parallel_world_size
()
self
.
tenso
r_model_parallel_size
=
get_
tenso
r_model_parallel_world_size
()
# Divide the weight matrix along the vocaburaly dimension.
# Divide the weight matrix along the vocaburaly dimension.
self
.
vocab_start_index
,
self
.
vocab_end_index
=
\
self
.
vocab_start_index
,
self
.
vocab_end_index
=
\
VocabUtility
.
vocab_range_from_global_vocab_size
(
VocabUtility
.
vocab_range_from_global_vocab_size
(
self
.
num_embeddings
,
get_
intra_laye
r_model_parallel_rank
(),
self
.
num_embeddings
,
get_
tenso
r_model_parallel_rank
(),
self
.
intra_laye
r_model_parallel_size
)
self
.
tenso
r_model_parallel_size
)
self
.
num_embeddings_per_partition
=
self
.
vocab_end_index
-
\
self
.
num_embeddings_per_partition
=
self
.
vocab_end_index
-
\
self
.
vocab_start_index
self
.
vocab_start_index
...
@@ -145,7 +145,7 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -145,7 +145,7 @@ class VocabParallelEmbedding(torch.nn.Module):
partition_dim
=
0
,
stride
=
1
)
partition_dim
=
0
,
stride
=
1
)
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
if
self
.
intra_laye
r_model_parallel_size
>
1
:
if
self
.
tenso
r_model_parallel_size
>
1
:
# Build the mask.
# Build the mask.
input_mask
=
(
input_
<
self
.
vocab_start_index
)
|
\
input_mask
=
(
input_
<
self
.
vocab_start_index
)
|
\
(
input_
>=
self
.
vocab_end_index
)
(
input_
>=
self
.
vocab_end_index
)
...
@@ -160,10 +160,10 @@ class VocabParallelEmbedding(torch.nn.Module):
...
@@ -160,10 +160,10 @@ class VocabParallelEmbedding(torch.nn.Module):
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
sparse
)
self
.
sparse
)
# Mask the output embedding.
# Mask the output embedding.
if
self
.
intra_laye
r_model_parallel_size
>
1
:
if
self
.
tenso
r_model_parallel_size
>
1
:
output_parallel
[
input_mask
,
:]
=
0.0
output_parallel
[
input_mask
,
:]
=
0.0
# Reduce across all the model parallel GPUs.
# Reduce across all the model parallel GPUs.
output
=
reduce_from_
intra_laye
r_model_parallel_region
(
output_parallel
)
output
=
reduce_from_
tenso
r_model_parallel_region
(
output_parallel
)
return
output
return
output
...
@@ -202,7 +202,7 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -202,7 +202,7 @@ class ColumnParallelLinear(torch.nn.Module):
self
.
output_size
=
output_size
self
.
output_size
=
output_size
self
.
gather_output
=
gather_output
self
.
gather_output
=
gather_output
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
world_size
=
get_
intra_laye
r_model_parallel_world_size
()
world_size
=
get_
tenso
r_model_parallel_world_size
()
self
.
output_size_per_partition
=
divide
(
output_size
,
world_size
)
self
.
output_size_per_partition
=
divide
(
output_size
,
world_size
)
self
.
skip_bias_add
=
skip_bias_add
self
.
skip_bias_add
=
skip_bias_add
...
@@ -235,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -235,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
self
.
output_size_per_partition
,
self
.
output_size_per_partition
,
device
=
torch
.
cuda
.
current_device
(),
device
=
torch
.
cuda
.
current_device
(),
dtype
=
args
.
params_dtype
))
dtype
=
args
.
params_dtype
))
self
.
bias
.
intra_laye
r_model_parallel
=
True
self
.
bias
.
tenso
r_model_parallel
=
True
self
.
bias
.
partition_dim
=
0
self
.
bias
.
partition_dim
=
0
self
.
bias
.
stride
=
stride
self
.
bias
.
stride
=
stride
# Always initialize bias to zero.
# Always initialize bias to zero.
...
@@ -248,14 +248,14 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -248,14 +248,14 @@ class ColumnParallelLinear(torch.nn.Module):
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
# Set up backprop all-reduce.
# Set up backprop all-reduce.
input_parallel
=
copy_to_
intra_laye
r_model_parallel_region
(
input_
)
input_parallel
=
copy_to_
tenso
r_model_parallel_region
(
input_
)
# Matrix multiply.
# Matrix multiply.
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
output_parallel
=
F
.
linear
(
input_parallel
,
self
.
weight
,
bias
)
output_parallel
=
F
.
linear
(
input_parallel
,
self
.
weight
,
bias
)
if
self
.
gather_output
:
if
self
.
gather_output
:
# All-gather across the partitions.
# All-gather across the partitions.
output
=
gather_from_
intra_laye
r_model_parallel_region
(
output_parallel
)
output
=
gather_from_
tenso
r_model_parallel_region
(
output_parallel
)
else
:
else
:
output
=
output_parallel
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
...
@@ -304,7 +304,7 @@ class RowParallelLinear(torch.nn.Module):
...
@@ -304,7 +304,7 @@ class RowParallelLinear(torch.nn.Module):
self
.
output_size
=
output_size
self
.
output_size
=
output_size
self
.
input_is_parallel
=
input_is_parallel
self
.
input_is_parallel
=
input_is_parallel
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
world_size
=
get_
intra_laye
r_model_parallel_world_size
()
world_size
=
get_
tenso
r_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
world_size
)
self
.
input_size_per_partition
=
divide
(
input_size
,
world_size
)
self
.
skip_bias_add
=
skip_bias_add
self
.
skip_bias_add
=
skip_bias_add
...
@@ -348,11 +348,11 @@ class RowParallelLinear(torch.nn.Module):
...
@@ -348,11 +348,11 @@ class RowParallelLinear(torch.nn.Module):
if
self
.
input_is_parallel
:
if
self
.
input_is_parallel
:
input_parallel
=
input_
input_parallel
=
input_
else
:
else
:
input_parallel
=
scatter_to_
intra_laye
r_model_parallel_region
(
input_
)
input_parallel
=
scatter_to_
tenso
r_model_parallel_region
(
input_
)
# Matrix multiply.
# Matrix multiply.
output_parallel
=
F
.
linear
(
input_parallel
,
self
.
weight
)
output_parallel
=
F
.
linear
(
input_parallel
,
self
.
weight
)
# All-reduce across all the partitions.
# All-reduce across all the partitions.
output_
=
reduce_from_
intra_laye
r_model_parallel_region
(
output_parallel
)
output_
=
reduce_from_
tenso
r_model_parallel_region
(
output_parallel
)
if
not
self
.
skip_bias_add
:
if
not
self
.
skip_bias_add
:
output
=
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
output
=
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
output_bias
=
None
output_bias
=
None
...
...
megatron/mpu/mappings.py
View file @
52a5f2f2
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
import
torch
import
torch
from
.initialize
import
get_
intra_laye
r_model_parallel_group
,
get_
intra_laye
r_model_parallel_world_size
,
get_
intra_laye
r_model_parallel_rank
from
.initialize
import
get_
tenso
r_model_parallel_group
,
get_
tenso
r_model_parallel_world_size
,
get_
tenso
r_model_parallel_rank
from
.utils
import
split_tensor_along_last_dim
from
.utils
import
split_tensor_along_last_dim
...
@@ -23,11 +23,11 @@ def _reduce(input_):
...
@@ -23,11 +23,11 @@ def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
"""All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
# Bypass the function if we are using only 1 GPU.
if
get_
intra_laye
r_model_parallel_world_size
()
==
1
:
if
get_
tenso
r_model_parallel_world_size
()
==
1
:
return
input_
return
input_
# All-reduce.
# All-reduce.
torch
.
distributed
.
all_reduce
(
input_
,
group
=
get_
intra_laye
r_model_parallel_group
())
torch
.
distributed
.
all_reduce
(
input_
,
group
=
get_
tenso
r_model_parallel_group
())
return
input_
return
input_
...
@@ -36,7 +36,7 @@ def _split(input_):
...
@@ -36,7 +36,7 @@ def _split(input_):
"""Split the tensor along its last dimension and keep the
"""Split the tensor along its last dimension and keep the
corresponding slice."""
corresponding slice."""
world_size
=
get_
intra_laye
r_model_parallel_world_size
()
world_size
=
get_
tenso
r_model_parallel_world_size
()
# Bypass the function if we are using only 1 GPU.
# Bypass the function if we are using only 1 GPU.
if
world_size
==
1
:
if
world_size
==
1
:
return
input_
return
input_
...
@@ -45,7 +45,7 @@ def _split(input_):
...
@@ -45,7 +45,7 @@ def _split(input_):
input_list
=
split_tensor_along_last_dim
(
input_
,
world_size
)
input_list
=
split_tensor_along_last_dim
(
input_
,
world_size
)
# Note: torch.split does not create contiguous tensors by default.
# Note: torch.split does not create contiguous tensors by default.
rank
=
get_
intra_laye
r_model_parallel_rank
()
rank
=
get_
tenso
r_model_parallel_rank
()
output
=
input_list
[
rank
].
contiguous
()
output
=
input_list
[
rank
].
contiguous
()
return
output
return
output
...
@@ -54,18 +54,18 @@ def _split(input_):
...
@@ -54,18 +54,18 @@ def _split(input_):
def
_gather
(
input_
):
def
_gather
(
input_
):
"""Gather tensors and concatinate along the last dimension."""
"""Gather tensors and concatinate along the last dimension."""
world_size
=
get_
intra_laye
r_model_parallel_world_size
()
world_size
=
get_
tenso
r_model_parallel_world_size
()
# Bypass the function if we are using only 1 GPU.
# Bypass the function if we are using only 1 GPU.
if
world_size
==
1
:
if
world_size
==
1
:
return
input_
return
input_
# Size and dimension.
# Size and dimension.
last_dim
=
input_
.
dim
()
-
1
last_dim
=
input_
.
dim
()
-
1
rank
=
get_
intra_laye
r_model_parallel_rank
()
rank
=
get_
tenso
r_model_parallel_rank
()
tensor_list
=
[
torch
.
empty_like
(
input_
)
for
_
in
range
(
world_size
)]
tensor_list
=
[
torch
.
empty_like
(
input_
)
for
_
in
range
(
world_size
)]
tensor_list
[
rank
]
=
input_
tensor_list
[
rank
]
=
input_
torch
.
distributed
.
all_gather
(
tensor_list
,
input_
,
group
=
get_
intra_laye
r_model_parallel_group
())
torch
.
distributed
.
all_gather
(
tensor_list
,
input_
,
group
=
get_
tenso
r_model_parallel_group
())
# Note: torch.cat already creates a contiguous tensor.
# Note: torch.cat already creates a contiguous tensor.
output
=
torch
.
cat
(
tensor_list
,
dim
=
last_dim
).
contiguous
()
output
=
torch
.
cat
(
tensor_list
,
dim
=
last_dim
).
contiguous
()
...
@@ -141,17 +141,17 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
...
@@ -141,17 +141,17 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
# Helper functions.
# Helper functions.
# -----------------
# -----------------
def
copy_to_
intra_laye
r_model_parallel_region
(
input_
):
def
copy_to_
tenso
r_model_parallel_region
(
input_
):
return
_CopyToModelParallelRegion
.
apply
(
input_
)
return
_CopyToModelParallelRegion
.
apply
(
input_
)
def
reduce_from_
intra_laye
r_model_parallel_region
(
input_
):
def
reduce_from_
tenso
r_model_parallel_region
(
input_
):
return
_ReduceFromModelParallelRegion
.
apply
(
input_
)
return
_ReduceFromModelParallelRegion
.
apply
(
input_
)
def
scatter_to_
intra_laye
r_model_parallel_region
(
input_
):
def
scatter_to_
tenso
r_model_parallel_region
(
input_
):
return
_ScatterToModelParallelRegion
.
apply
(
input_
)
return
_ScatterToModelParallelRegion
.
apply
(
input_
)
def
gather_from_
intra_laye
r_model_parallel_region
(
input_
):
def
gather_from_
tenso
r_model_parallel_region
(
input_
):
return
_GatherFromModelParallelRegion
.
apply
(
input_
)
return
_GatherFromModelParallelRegion
.
apply
(
input_
)
megatron/mpu/random.py
View file @
52a5f2f2
...
@@ -28,13 +28,13 @@ from megatron import get_args
...
@@ -28,13 +28,13 @@ from megatron import get_args
from
megatron.memory
import
allocate_mem_buff
from
megatron.memory
import
allocate_mem_buff
from
.initialize
import
get_data_parallel_rank
from
.initialize
import
get_data_parallel_rank
from
.initialize
import
get_
intra_laye
r_model_parallel_group
from
.initialize
import
get_
tenso
r_model_parallel_group
from
.initialize
import
get_
intra_laye
r_model_parallel_rank
from
.initialize
import
get_
tenso
r_model_parallel_rank
from
.initialize
import
get_
intra_laye
r_model_parallel_world_size
from
.initialize
import
get_
tenso
r_model_parallel_world_size
# Default name for the model parallel rng tracker.
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME
=
'
intra-laye
r-model-parallel-rng'
_MODEL_PARALLEL_RNG_TRACKER_NAME
=
'
tenso
r-model-parallel-rng'
# Whether apply model parallelsim to checkpointed hidden states.
# Whether apply model parallelsim to checkpointed hidden states.
...
@@ -104,15 +104,15 @@ def _set_cuda_rng_state(new_state, device=-1):
...
@@ -104,15 +104,15 @@ def _set_cuda_rng_state(new_state, device=-1):
def
split_tensor_into_1d_equal_chunks
(
tensor
):
def
split_tensor_into_1d_equal_chunks
(
tensor
):
"""Break a tensor into equal 1D chunks."""
"""Break a tensor into equal 1D chunks."""
data
=
tensor
.
view
(
-
1
)
data
=
tensor
.
view
(
-
1
)
partition_size
=
torch
.
numel
(
data
)
//
get_
intra_laye
r_model_parallel_world_size
()
partition_size
=
torch
.
numel
(
data
)
//
get_
tenso
r_model_parallel_world_size
()
start_index
=
partition_size
*
get_
intra_laye
r_model_parallel_rank
()
start_index
=
partition_size
*
get_
tenso
r_model_parallel_rank
()
end_index
=
start_index
+
partition_size
end_index
=
start_index
+
partition_size
return
data
[
start_index
:
end_index
]
return
data
[
start_index
:
end_index
]
def
gather_split_1d_tensor
(
tensor
):
def
gather_split_1d_tensor
(
tensor
):
"""Opposite of above function, gather values from model parallel ranks."""
"""Opposite of above function, gather values from model parallel ranks."""
world_size
=
get_
intra_laye
r_model_parallel_world_size
()
world_size
=
get_
tenso
r_model_parallel_world_size
()
numel
=
torch
.
numel
(
tensor
)
numel
=
torch
.
numel
(
tensor
)
numel_gathered
=
world_size
*
numel
numel_gathered
=
world_size
*
numel
gathered
=
torch
.
empty
(
numel_gathered
,
dtype
=
tensor
.
dtype
,
gathered
=
torch
.
empty
(
numel_gathered
,
dtype
=
tensor
.
dtype
,
...
@@ -120,7 +120,7 @@ def gather_split_1d_tensor(tensor):
...
@@ -120,7 +120,7 @@ def gather_split_1d_tensor(tensor):
requires_grad
=
False
)
requires_grad
=
False
)
chunks
=
[
gathered
[
i
*
numel
:(
i
+
1
)
*
numel
]
for
i
in
range
(
world_size
)]
chunks
=
[
gathered
[
i
*
numel
:(
i
+
1
)
*
numel
]
for
i
in
range
(
world_size
)]
torch
.
distributed
.
all_gather
(
chunks
,
tensor
,
torch
.
distributed
.
all_gather
(
chunks
,
tensor
,
group
=
get_
intra_laye
r_model_parallel_group
())
group
=
get_
tenso
r_model_parallel_group
())
return
gathered
return
gathered
...
@@ -204,7 +204,7 @@ def get_cuda_rng_tracker():
...
@@ -204,7 +204,7 @@ def get_cuda_rng_tracker():
return
_CUDA_RNG_STATE_TRACKER
return
_CUDA_RNG_STATE_TRACKER
def
intra_layer_
model_parallel_cuda_manual_seed
(
seed
):
def
model_parallel_cuda_manual_seed
(
seed
):
"""Initialize model parallel cuda seed.
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
This function should be called after the model parallel is
...
@@ -215,15 +215,15 @@ def intra_layer_model_parallel_cuda_manual_seed(seed):
...
@@ -215,15 +215,15 @@ def intra_layer_model_parallel_cuda_manual_seed(seed):
default state: This is for data parallelism and is the same among a
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
set of model parallel GPUs but different across
different model paralle groups. This is used for
different model paralle groups. This is used for
example for dropout in the non-
intra-laye
r-model-parallel regions.
example for dropout in the non-
tenso
r-model-parallel regions.
intra-laye
r-model-parallel state: This state is different among a set of model
tenso
r-model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
groups. This is used for example for dropout in
model parallel regions.
model parallel regions.
"""
"""
# 2718 is just for fun and any POSITIVE value will work.
# 2718 is just for fun and any POSITIVE value will work.
offset
=
seed
+
2718
offset
=
seed
+
2718
intra_laye
r_model_parallel_seed
=
offset
+
get_
intra_laye
r_model_parallel_rank
()
tenso
r_model_parallel_seed
=
offset
+
get_
tenso
r_model_parallel_rank
()
# Data parallel gets the original sedd.
# Data parallel gets the original sedd.
data_parallel_seed
=
seed
data_parallel_seed
=
seed
...
@@ -231,15 +231,15 @@ def intra_layer_model_parallel_cuda_manual_seed(seed):
...
@@ -231,15 +231,15 @@ def intra_layer_model_parallel_cuda_manual_seed(seed):
print
(
'> initializing model parallel cuda seeds on global rank {}, '
print
(
'> initializing model parallel cuda seeds on global rank {}, '
'model parallel rank {}, and data parallel rank {} with '
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'
.
format
(
'model parallel seed: {} and data parallel seed: {}'
.
format
(
torch
.
distributed
.
get_rank
(),
get_
intra_laye
r_model_parallel_rank
(),
torch
.
distributed
.
get_rank
(),
get_
tenso
r_model_parallel_rank
(),
get_data_parallel_rank
(),
intra_laye
r_model_parallel_seed
,
get_data_parallel_rank
(),
tenso
r_model_parallel_seed
,
data_parallel_seed
),
flush
=
True
)
data_parallel_seed
),
flush
=
True
)
_CUDA_RNG_STATE_TRACKER
.
reset
()
_CUDA_RNG_STATE_TRACKER
.
reset
()
# Set the default state.
# Set the default state.
torch
.
cuda
.
manual_seed
(
data_parallel_seed
)
torch
.
cuda
.
manual_seed
(
data_parallel_seed
)
# and model parallel state.
# and model parallel state.
_CUDA_RNG_STATE_TRACKER
.
add
(
_MODEL_PARALLEL_RNG_TRACKER_NAME
,
_CUDA_RNG_STATE_TRACKER
.
add
(
_MODEL_PARALLEL_RNG_TRACKER_NAME
,
intra_laye
r_model_parallel_seed
)
tenso
r_model_parallel_seed
)
class
CheckpointFunction
(
torch
.
autograd
.
Function
):
class
CheckpointFunction
(
torch
.
autograd
.
Function
):
...
...
megatron/mpu/tests/commons.py
View file @
52a5f2f2
...
@@ -36,7 +36,7 @@ def set_random_seed(seed):
...
@@ -36,7 +36,7 @@ def set_random_seed(seed):
random
.
seed
(
seed
)
random
.
seed
(
seed
)
numpy
.
random
.
seed
(
seed
)
numpy
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
mpu
.
intra_layer_
model_parallel_cuda_manual_seed
(
seed
)
mpu
.
model_parallel_cuda_manual_seed
(
seed
)
def
initialize_distributed
(
backend
=
'nccl'
):
def
initialize_distributed
(
backend
=
'nccl'
):
...
...
megatron/mpu/tests/test_cross_entropy.py
View file @
52a5f2f2
...
@@ -47,7 +47,7 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
...
@@ -47,7 +47,7 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
identity
=
IdentityLayer
((
batch_size
,
seq_length
,
vocab_size
),
identity
=
IdentityLayer
((
batch_size
,
seq_length
,
vocab_size
),
scale
=
logits_scale
).
cuda
()
scale
=
logits_scale
).
cuda
()
logits
=
identity
()
logits
=
identity
()
logits_parallel
=
mpu
.
scatter_to_
intra_laye
r_model_parallel_region
(
logits
)
logits_parallel
=
mpu
.
scatter_to_
tenso
r_model_parallel_region
(
logits
)
target
=
torch
.
cuda
.
LongTensor
(
target
=
torch
.
cuda
.
LongTensor
(
size
=
(
batch_size
,
seq_length
)).
random_
(
0
,
vocab_size
)
size
=
(
batch_size
,
seq_length
)).
random_
(
0
,
vocab_size
)
loss
=
vocab_parallel_cross_entropy
(
logits_parallel
,
target
).
mean
()
loss
=
vocab_parallel_cross_entropy
(
logits_parallel
,
target
).
mean
()
...
@@ -55,20 +55,20 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
...
@@ -55,20 +55,20 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size,
return
loss
,
identity
.
weight
.
grad
return
loss
,
identity
.
weight
.
grad
def
test_cross_entropy
(
intra_laye
r_model_parallel_size
):
def
test_cross_entropy
(
tenso
r_model_parallel_size
):
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing cross entropy with model parallel size {} ...'
.
print
(
'> testing cross entropy with model parallel size {} ...'
.
format
(
intra_laye
r_model_parallel_size
))
format
(
tenso
r_model_parallel_size
))
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
batch_size
=
13
batch_size
=
13
seq_length
=
17
seq_length
=
17
vocab_size_per_partition
=
11
vocab_size_per_partition
=
11
logits_scale
=
1000.0
logits_scale
=
1000.0
vocab_size
=
vocab_size_per_partition
*
intra_laye
r_model_parallel_size
vocab_size
=
vocab_size_per_partition
*
tenso
r_model_parallel_size
seed
=
1234
seed
=
1234
loss_torch
,
grad_torch
=
torch_cross_entropy
(
batch_size
,
seq_length
,
loss_torch
,
grad_torch
=
torch_cross_entropy
(
batch_size
,
seq_length
,
...
@@ -89,7 +89,7 @@ def test_cross_entropy(intra_layer_model_parallel_size):
...
@@ -89,7 +89,7 @@ def test_cross_entropy(intra_layer_model_parallel_size):
assert
error
<
1.0e-6
assert
error
<
1.0e-6
# Reset groups
# Reset groups
mpu
.
destroy_
intra_laye
r_model_parallel
()
mpu
.
destroy_
tenso
r_model_parallel
()
torch
.
distributed
.
barrier
()
torch
.
distributed
.
barrier
()
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
...
@@ -101,8 +101,8 @@ if __name__ == '__main__':
...
@@ -101,8 +101,8 @@ if __name__ == '__main__':
initialize_distributed
()
initialize_distributed
()
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
print_separator
(
'test cross entropy'
)
print_separator
(
'test cross entropy'
)
test_cross_entropy
(
intra_laye
r_model_parallel_size
)
test_cross_entropy
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
megatron/mpu/tests/test_data.py
View file @
52a5f2f2
...
@@ -24,15 +24,15 @@ import sys
...
@@ -24,15 +24,15 @@ import sys
sys
.
path
.
append
(
"../.."
)
sys
.
path
.
append
(
"../.."
)
def
test_broadcast_data
(
intra_laye
r_model_parallel_size
):
def
test_broadcast_data
(
tenso
r_model_parallel_size
):
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing broadcast_data with model parallel size {} ...'
.
print
(
'> testing broadcast_data with model parallel size {} ...'
.
format
(
intra_laye
r_model_parallel_size
))
format
(
tenso
r_model_parallel_size
))
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
torch
.
manual_seed
(
1234
+
mpu
.
get_data_parallel_rank
())
torch
.
manual_seed
(
1234
+
mpu
.
get_data_parallel_rank
())
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
key_size_t
=
{
'key1'
:
[
7
,
11
],
key_size_t
=
{
'key1'
:
[
7
,
11
],
'key2'
:
[
8
,
2
,
1
],
'key2'
:
[
8
,
2
,
1
],
...
@@ -48,7 +48,7 @@ def test_broadcast_data(intra_layer_model_parallel_size):
...
@@ -48,7 +48,7 @@ def test_broadcast_data(intra_layer_model_parallel_size):
data_t
[
key
]
=
data
[
key
].
clone
()
data_t
[
key
]
=
data
[
key
].
clone
()
data
[
'keyX'
]
=
torch
.
FloatTensor
(
size
=
(
5
,
)).
random_
(
0
,
1000
)
data
[
'keyX'
]
=
torch
.
FloatTensor
(
size
=
(
5
,
)).
random_
(
0
,
1000
)
data_t
[
'keyX'
]
=
data
[
'keyX'
].
clone
()
data_t
[
'keyX'
]
=
data
[
'keyX'
].
clone
()
if
mpu
.
get_
intra_laye
r_model_parallel_rank
()
!=
0
:
if
mpu
.
get_
tenso
r_model_parallel_rank
()
!=
0
:
data
=
None
data
=
None
data_utils
.
_check_data_types
(
keys
,
data_t
,
torch
.
int64
)
data_utils
.
_check_data_types
(
keys
,
data_t
,
torch
.
int64
)
...
@@ -69,7 +69,7 @@ def test_broadcast_data(intra_layer_model_parallel_size):
...
@@ -69,7 +69,7 @@ def test_broadcast_data(intra_layer_model_parallel_size):
assert
data_b
[
key
].
sub
(
tensor
).
abs
().
max
()
==
0
assert
data_b
[
key
].
sub
(
tensor
).
abs
().
max
()
==
0
# Reset groups
# Reset groups
mpu
.
destroy_
intra_laye
r_model_parallel
()
mpu
.
destroy_
tenso
r_model_parallel
()
torch
.
distributed
.
barrier
()
torch
.
distributed
.
barrier
()
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
...
@@ -81,8 +81,8 @@ if __name__ == '__main__':
...
@@ -81,8 +81,8 @@ if __name__ == '__main__':
initialize_distributed
()
initialize_distributed
()
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
print_separator
(
'test test broadcast data'
)
print_separator
(
'test test broadcast data'
)
test_broadcast_data
(
intra_laye
r_model_parallel_size
)
test_broadcast_data
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
megatron/mpu/tests/test_initialize.py
View file @
52a5f2f2
...
@@ -21,15 +21,15 @@ import sys
...
@@ -21,15 +21,15 @@ import sys
sys
.
path
.
append
(
"../.."
)
sys
.
path
.
append
(
"../.."
)
def
test_initialize_model_parallel
(
intra_laye
r_model_parallel_size
):
def
test_initialize_model_parallel
(
tenso
r_model_parallel_size
):
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing initialize_model_parallel with size {} ...'
.
format
(
print
(
'> testing initialize_model_parallel with size {} ...'
.
format
(
intra_laye
r_model_parallel_size
))
tenso
r_model_parallel_size
))
intra_laye
r_model_parallel_size_
=
min
(
intra_laye
r_model_parallel_size
,
tenso
r_model_parallel_size_
=
min
(
tenso
r_model_parallel_size
,
torch
.
distributed
.
get_world_size
())
torch
.
distributed
.
get_world_size
())
assert
not
mpu
.
model_parallel_is_initialized
()
assert
not
mpu
.
model_parallel_is_initialized
()
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size_
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size_
)
assert
mpu
.
model_parallel_is_initialized
()
assert
mpu
.
model_parallel_is_initialized
()
# Checks.
# Checks.
...
@@ -38,15 +38,15 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size):
...
@@ -38,15 +38,15 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size):
assert
rank
==
torch
.
distributed
.
get_rank
(
group
=
group
)
assert
rank
==
torch
.
distributed
.
get_rank
(
group
=
group
)
# Model parallel.
# Model parallel.
world_size
=
intra_laye
r_model_parallel_size_
world_size
=
tenso
r_model_parallel_size_
rank
=
torch
.
distributed
.
get_rank
()
%
intra_laye
r_model_parallel_size_
rank
=
torch
.
distributed
.
get_rank
()
%
tenso
r_model_parallel_size_
assert
world_size
==
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
assert
world_size
==
mpu
.
get_
tenso
r_model_parallel_world_size
()
assert
rank
==
mpu
.
get_
intra_laye
r_model_parallel_rank
()
assert
rank
==
mpu
.
get_
tenso
r_model_parallel_rank
()
check
(
mpu
.
get_
intra_laye
r_model_parallel_group
(),
world_size
,
rank
)
check
(
mpu
.
get_
tenso
r_model_parallel_group
(),
world_size
,
rank
)
# Data parallel.
# Data parallel.
world_size
=
torch
.
distributed
.
get_world_size
()
//
intra_laye
r_model_parallel_size_
world_size
=
torch
.
distributed
.
get_world_size
()
//
tenso
r_model_parallel_size_
rank
=
torch
.
distributed
.
get_rank
()
//
intra_laye
r_model_parallel_size
rank
=
torch
.
distributed
.
get_rank
()
//
tenso
r_model_parallel_size
assert
world_size
==
mpu
.
get_data_parallel_world_size
()
assert
world_size
==
mpu
.
get_data_parallel_world_size
()
assert
rank
==
mpu
.
get_data_parallel_rank
()
assert
rank
==
mpu
.
get_data_parallel_rank
()
check
(
mpu
.
get_data_parallel_group
(),
world_size
,
rank
)
check
(
mpu
.
get_data_parallel_group
(),
world_size
,
rank
)
...
@@ -59,20 +59,20 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size):
...
@@ -59,20 +59,20 @@ def test_initialize_model_parallel(intra_layer_model_parallel_size):
print
(
'>> passed the test :-)'
)
print
(
'>> passed the test :-)'
)
def
test_get_
intra_laye
r_model_parallel_src_rank
(
intra_laye
r_model_parallel_size_
):
def
test_get_
tenso
r_model_parallel_src_rank
(
tenso
r_model_parallel_size_
):
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing get_
intra_laye
r_model_parallel_src_rank with size {} ...'
.
format
(
print
(
'> testing get_
tenso
r_model_parallel_src_rank with size {} ...'
.
format
(
intra_laye
r_model_parallel_size_
))
tenso
r_model_parallel_size_
))
intra_laye
r_model_parallel_size
=
min
(
intra_laye
r_model_parallel_size_
,
tenso
r_model_parallel_size
=
min
(
tenso
r_model_parallel_size_
,
torch
.
distributed
.
get_world_size
())
torch
.
distributed
.
get_world_size
())
assert
not
mpu
.
model_parallel_is_initialized
()
assert
not
mpu
.
model_parallel_is_initialized
()
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
assert
mpu
.
model_parallel_is_initialized
()
assert
mpu
.
model_parallel_is_initialized
()
# Checks
# Checks
src_rank
=
torch
.
distributed
.
get_rank
()
-
mpu
.
get_
intra_laye
r_model_parallel_rank
()
src_rank
=
torch
.
distributed
.
get_rank
()
-
mpu
.
get_
tenso
r_model_parallel_rank
()
assert
mpu
.
get_
intra_laye
r_model_parallel_src_rank
()
==
src_rank
assert
mpu
.
get_
tenso
r_model_parallel_src_rank
()
==
src_rank
# Reset groups
# Reset groups
mpu
.
destroy_model_parallel
()
mpu
.
destroy_model_parallel
()
...
@@ -86,10 +86,10 @@ if __name__ == '__main__':
...
@@ -86,10 +86,10 @@ if __name__ == '__main__':
initialize_distributed
()
initialize_distributed
()
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
print_separator
(
'test initialize model parallel'
)
print_separator
(
'test initialize model parallel'
)
test_initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
test_initialize_model_parallel
(
tenso
r_model_parallel_size
)
print_separator
(
'test model parallel source rank'
)
print_separator
(
'test model parallel source rank'
)
test_get_
intra_laye
r_model_parallel_src_rank
(
intra_laye
r_model_parallel_size
)
test_get_
tenso
r_model_parallel_src_rank
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
megatron/mpu/tests/test_layers.py
View file @
52a5f2f2
...
@@ -26,14 +26,14 @@ import sys
...
@@ -26,14 +26,14 @@ import sys
sys
.
path
.
append
(
"../.."
)
sys
.
path
.
append
(
"../.."
)
def
test_parallel_embedding
(
intra_laye
r_model_parallel_size
):
def
test_parallel_embedding
(
tenso
r_model_parallel_size
):
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing parallel embedding with model parallel size {} ...'
.
print
(
'> testing parallel embedding with model parallel size {} ...'
.
format
(
intra_laye
r_model_parallel_size
))
format
(
tenso
r_model_parallel_size
))
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
batch_size
=
17
batch_size
=
17
seq_length
=
23
seq_length
=
23
...
@@ -80,16 +80,16 @@ def test_parallel_embedding(intra_layer_model_parallel_size):
...
@@ -80,16 +80,16 @@ def test_parallel_embedding(intra_layer_model_parallel_size):
assert
error
<
1.0e-12
,
'error: {}'
.
format
(
error
)
assert
error
<
1.0e-12
,
'error: {}'
.
format
(
error
)
weight_grad_orig
=
torch
.
split
(
embedding_original
.
weight
.
grad
,
weight_grad_orig
=
torch
.
split
(
embedding_original
.
weight
.
grad
,
hidden_size
//
intra_laye
r_model_parallel_size
,
hidden_size
//
tenso
r_model_parallel_size
,
1
)[
mpu
.
get_
intra_laye
r_model_parallel_rank
()]
1
)[
mpu
.
get_
tenso
r_model_parallel_rank
()]
error
=
embedding_parallel
.
weight
.
grad
.
sub
(
weight_grad_orig
).
abs
().
max
()
error
=
embedding_parallel
.
weight
.
grad
.
sub
(
weight_grad_orig
).
abs
().
max
()
print
(
' error in grad (parallel) on global rank {}: {}'
.
format
(
print
(
' error in grad (parallel) on global rank {}: {}'
.
format
(
torch
.
distributed
.
get_rank
(),
error
))
torch
.
distributed
.
get_rank
(),
error
))
assert
error
<
1.0e-12
,
'error: {}'
.
format
(
error
)
assert
error
<
1.0e-12
,
'error: {}'
.
format
(
error
)
weight_grad_orig
=
torch
.
split
(
embedding_original
.
weight
.
grad
,
weight_grad_orig
=
torch
.
split
(
embedding_original
.
weight
.
grad
,
vocab_size
//
intra_laye
r_model_parallel_size
,
vocab_size
//
tenso
r_model_parallel_size
,
0
)[
mpu
.
get_
intra_laye
r_model_parallel_rank
()]
0
)[
mpu
.
get_
tenso
r_model_parallel_rank
()]
error
=
embedding_vocab_parallel
.
weight
.
grad
.
sub
(
error
=
embedding_vocab_parallel
.
weight
.
grad
.
sub
(
weight_grad_orig
).
abs
().
max
()
weight_grad_orig
).
abs
().
max
()
print
(
' error in grad (vocab parallel) on global rank {}: {}'
.
format
(
print
(
' error in grad (vocab parallel) on global rank {}: {}'
.
format
(
...
@@ -104,19 +104,19 @@ def test_parallel_embedding(intra_layer_model_parallel_size):
...
@@ -104,19 +104,19 @@ def test_parallel_embedding(intra_layer_model_parallel_size):
print
(
'>> passed the test :-)'
)
print
(
'>> passed the test :-)'
)
def
test_initialize_affine_weight
(
intra_laye
r_model_parallel_size
):
def
test_initialize_affine_weight
(
tenso
r_model_parallel_size
):
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing initialize_affine_weight with model parallel '
print
(
'> testing initialize_affine_weight with model parallel '
'size: {}'
.
format
(
intra_laye
r_model_parallel_size
))
'size: {}'
.
format
(
tenso
r_model_parallel_size
))
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
seed
=
12345
seed
=
12345
input_size_coeff
=
13
input_size_coeff
=
13
input_size
=
input_size_coeff
*
intra_laye
r_model_parallel_size
input_size
=
input_size_coeff
*
tenso
r_model_parallel_size
output_size_coeff
=
17
output_size_coeff
=
17
output_size
=
output_size_coeff
*
intra_laye
r_model_parallel_size
output_size
=
output_size_coeff
*
tenso
r_model_parallel_size
# ---------------
# ---------------
# Column parallel
# Column parallel
...
@@ -131,7 +131,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size):
...
@@ -131,7 +131,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size):
set_random_seed
(
seed
)
set_random_seed
(
seed
)
master_weight
=
torch
.
empty
(
output_size
,
input_size
)
master_weight
=
torch
.
empty
(
output_size
,
input_size
)
torch
.
nn
.
init
.
normal_
(
master_weight
)
torch
.
nn
.
init
.
normal_
(
master_weight
)
rank
=
mpu
.
get_
intra_laye
r_model_parallel_rank
()
rank
=
mpu
.
get_
tenso
r_model_parallel_rank
()
my_weight
=
torch
.
split
(
master_weight
,
output_size_coeff
,
my_weight
=
torch
.
split
(
master_weight
,
output_size_coeff
,
dim
=
0
)[
rank
].
contiguous
().
clone
()
dim
=
0
)[
rank
].
contiguous
().
clone
()
...
@@ -154,7 +154,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size):
...
@@ -154,7 +154,7 @@ def test_initialize_affine_weight(intra_layer_model_parallel_size):
set_random_seed
(
seed
)
set_random_seed
(
seed
)
master_weight
=
torch
.
empty
(
output_size
,
input_size
)
master_weight
=
torch
.
empty
(
output_size
,
input_size
)
torch
.
nn
.
init
.
normal_
(
master_weight
)
torch
.
nn
.
init
.
normal_
(
master_weight
)
rank
=
mpu
.
get_
intra_laye
r_model_parallel_rank
()
rank
=
mpu
.
get_
tenso
r_model_parallel_rank
()
my_weight
=
torch
.
split
(
master_weight
,
input_size_coeff
,
my_weight
=
torch
.
split
(
master_weight
,
input_size_coeff
,
dim
=
1
)[
rank
].
contiguous
().
clone
()
dim
=
1
)[
rank
].
contiguous
().
clone
()
...
@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module):
...
@@ -183,20 +183,20 @@ class IdentityLayer2D(torch.nn.Module):
return
self
.
weight
return
self
.
weight
def
test_column_parallel_linear
(
intra_laye
r_model_parallel_size
):
def
test_column_parallel_linear
(
tenso
r_model_parallel_size
):
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing ColumnParallelLinear with model parallel '
print
(
'> testing ColumnParallelLinear with model parallel '
'size: {}'
.
format
(
intra_laye
r_model_parallel_size
))
'size: {}'
.
format
(
tenso
r_model_parallel_size
))
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
seed
=
12345
seed
=
12345
set_random_seed
(
seed
)
set_random_seed
(
seed
)
input_size_coeff
=
13
input_size_coeff
=
13
input_size
=
input_size_coeff
*
intra_laye
r_model_parallel_size
input_size
=
input_size_coeff
*
tenso
r_model_parallel_size
output_size_coeff
=
17
output_size_coeff
=
17
output_size
=
output_size_coeff
*
intra_laye
r_model_parallel_size
output_size
=
output_size_coeff
*
tenso
r_model_parallel_size
batch_size
=
7
batch_size
=
7
# Network
# Network
...
@@ -219,7 +219,7 @@ def test_column_parallel_linear(intra_layer_model_parallel_size):
...
@@ -219,7 +219,7 @@ def test_column_parallel_linear(intra_layer_model_parallel_size):
dLdb
=
torch
.
matmul
(
torch
.
ones
(
batch_size
,
1
).
cuda
().
t
(),
dLdY
).
view
(
-
1
)
dLdb
=
torch
.
matmul
(
torch
.
ones
(
batch_size
,
1
).
cuda
().
t
(),
dLdY
).
view
(
-
1
)
dLdX
=
torch
.
matmul
(
dLdY
,
A
)
dLdX
=
torch
.
matmul
(
dLdY
,
A
)
rank
=
mpu
.
get_
intra_laye
r_model_parallel_rank
()
rank
=
mpu
.
get_
tenso
r_model_parallel_rank
()
my_dLdA
=
torch
.
split
(
dLdA
,
output_size_coeff
,
my_dLdA
=
torch
.
split
(
dLdA
,
output_size_coeff
,
dim
=
0
)[
rank
].
contiguous
().
clone
()
dim
=
0
)[
rank
].
contiguous
().
clone
()
error
=
my_dLdA
.
sub
(
linear_layer
.
weight
.
grad
).
abs
().
max
()
error
=
my_dLdA
.
sub
(
linear_layer
.
weight
.
grad
).
abs
().
max
()
...
@@ -250,20 +250,20 @@ def test_column_parallel_linear(intra_layer_model_parallel_size):
...
@@ -250,20 +250,20 @@ def test_column_parallel_linear(intra_layer_model_parallel_size):
print
(
' >> passed the test :-)'
)
print
(
' >> passed the test :-)'
)
def
test_row_parallel_linear
(
intra_laye
r_model_parallel_size
):
def
test_row_parallel_linear
(
tenso
r_model_parallel_size
):
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing RowParallelLinear with model parallel '
print
(
'> testing RowParallelLinear with model parallel '
'size: {}'
.
format
(
intra_laye
r_model_parallel_size
))
'size: {}'
.
format
(
tenso
r_model_parallel_size
))
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
seed
=
12345
seed
=
12345
set_random_seed
(
seed
)
set_random_seed
(
seed
)
input_size_coeff
=
13
input_size_coeff
=
13
input_size
=
input_size_coeff
*
intra_laye
r_model_parallel_size
input_size
=
input_size_coeff
*
tenso
r_model_parallel_size
output_size_coeff
=
17
output_size_coeff
=
17
output_size
=
output_size_coeff
*
intra_laye
r_model_parallel_size
output_size
=
output_size_coeff
*
tenso
r_model_parallel_size
batch_size
=
7
batch_size
=
7
# Network
# Network
...
@@ -286,7 +286,7 @@ def test_row_parallel_linear(intra_layer_model_parallel_size):
...
@@ -286,7 +286,7 @@ def test_row_parallel_linear(intra_layer_model_parallel_size):
dLdb
=
torch
.
matmul
(
torch
.
ones
(
batch_size
,
1
).
cuda
().
t
(),
dLdY
).
view
(
-
1
)
dLdb
=
torch
.
matmul
(
torch
.
ones
(
batch_size
,
1
).
cuda
().
t
(),
dLdY
).
view
(
-
1
)
dLdX
=
torch
.
matmul
(
dLdY
,
A
)
dLdX
=
torch
.
matmul
(
dLdY
,
A
)
rank
=
mpu
.
get_
intra_laye
r_model_parallel_rank
()
rank
=
mpu
.
get_
tenso
r_model_parallel_rank
()
my_dLdA
=
torch
.
split
(
dLdA
,
input_size_coeff
,
my_dLdA
=
torch
.
split
(
dLdA
,
input_size_coeff
,
dim
=
1
)[
rank
].
contiguous
().
clone
()
dim
=
1
)[
rank
].
contiguous
().
clone
()
error
=
my_dLdA
.
sub
(
linear_layer
.
weight
.
grad
).
abs
().
max
()
error
=
my_dLdA
.
sub
(
linear_layer
.
weight
.
grad
).
abs
().
max
()
...
@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module):
...
@@ -325,11 +325,11 @@ class IdentityLayer3D(torch.nn.Module):
return
self
.
weight
return
self
.
weight
def
parallel_self_attention
(
intra_laye
r_model_parallel_size
,
num_att_heads_per_partition
,
def
parallel_self_attention
(
tenso
r_model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
):
sequence_length
):
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
seed
=
12345
seed
=
12345
set_random_seed
(
seed
)
set_random_seed
(
seed
)
...
@@ -352,17 +352,17 @@ def parallel_self_attention(intra_layer_model_parallel_size, num_att_heads_per_p
...
@@ -352,17 +352,17 @@ def parallel_self_attention(intra_layer_model_parallel_size, num_att_heads_per_p
# Backward
# Backward
loss
.
backward
()
loss
.
backward
()
rank
=
mpu
.
get_
intra_laye
r_model_parallel_rank
()
rank
=
mpu
.
get_
tenso
r_model_parallel_rank
()
mpu
.
destroy_model_parallel
()
mpu
.
destroy_model_parallel
()
return
rank
,
hidden_size
,
intra_laye
r_model_parallel_size
,
loss
,
\
return
rank
,
hidden_size
,
tenso
r_model_parallel_size
,
loss
,
\
attention_layer
,
identity_layer
attention_layer
,
identity_layer
def
test_parallel_self_attention
(
intra_laye
r_model_parallel_size
):
def
test_parallel_self_attention
(
tenso
r_model_parallel_size
):
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing ParallelSelfAttention with model parallel '
print
(
'> testing ParallelSelfAttention with model parallel '
'size: {}'
.
format
(
intra_laye
r_model_parallel_size
))
'size: {}'
.
format
(
tenso
r_model_parallel_size
))
num_att_heads_per_partition
=
3
num_att_heads_per_partition
=
3
hidden_size_per_att_head
=
7
hidden_size_per_att_head
=
7
...
@@ -370,14 +370,14 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
...
@@ -370,14 +370,14 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
batch_size
=
5
batch_size
=
5
sequence_length
=
13
sequence_length
=
13
rank_1
,
hideen_size_1
,
intra_laye
r_model_parallel_size_1
,
loss_1
,
\
rank_1
,
hideen_size_1
,
tenso
r_model_parallel_size_1
,
loss_1
,
\
attention_layer_1
,
identity_layer_1
=
parallel_self_attention
(
attention_layer_1
,
identity_layer_1
=
parallel_self_attention
(
1
,
num_att_heads_per_partition
,
1
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
)
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
)
rank
,
hidden_size
,
intra_laye
r_model_parallel_size
,
loss
,
\
rank
,
hidden_size
,
tenso
r_model_parallel_size
,
loss
,
\
attention_layer
,
identity_layer
=
parallel_self_attention
(
attention_layer
,
identity_layer
=
parallel_self_attention
(
intra_laye
r_model_parallel_size
,
num_att_heads_per_partition
,
tenso
r_model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
)
hidden_size_per_att_head
,
dropout_prob
,
batch_size
,
sequence_length
)
assert
hideen_size_1
==
hidden_size
assert
hideen_size_1
==
hidden_size
...
@@ -389,7 +389,7 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
...
@@ -389,7 +389,7 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
my_lin_grad_list
=
torch
.
split
(
my_lin_grad_list
=
torch
.
split
(
attention_layer_1
.
query_key_value
.
weight
.
grad
,
attention_layer_1
.
query_key_value
.
weight
.
grad
,
hidden_size
//
intra_laye
r_model_parallel_size
,
0
)[
rank
::
intra_laye
r_model_parallel_size
]
hidden_size
//
tenso
r_model_parallel_size
,
0
)[
rank
::
tenso
r_model_parallel_size
]
my_lin_grad
=
torch
.
cat
(
my_lin_grad_list
,
dim
=
0
)
my_lin_grad
=
torch
.
cat
(
my_lin_grad_list
,
dim
=
0
)
error
=
my_lin_grad
.
sub
(
error
=
my_lin_grad
.
sub
(
attention_layer
.
query_key_value
.
weight
.
grad
).
abs
().
max
()
attention_layer
.
query_key_value
.
weight
.
grad
).
abs
().
max
()
...
@@ -410,11 +410,11 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
...
@@ -410,11 +410,11 @@ def test_parallel_self_attention(intra_layer_model_parallel_size):
print
(
' >> passed the test :-)'
)
print
(
' >> passed the test :-)'
)
def
parallel_transformer
(
intra_laye
r_model_parallel_size
,
num_att_heads_per_partition
,
def
parallel_transformer
(
tenso
r_model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
batch_size
,
sequence_length
):
hidden_size_per_att_head
,
batch_size
,
sequence_length
):
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
seed
=
12345
seed
=
12345
set_random_seed
(
seed
)
set_random_seed
(
seed
)
...
@@ -440,31 +440,31 @@ def parallel_transformer(intra_layer_model_parallel_size, num_att_heads_per_part
...
@@ -440,31 +440,31 @@ def parallel_transformer(intra_layer_model_parallel_size, num_att_heads_per_part
# Backward
# Backward
loss
.
backward
()
loss
.
backward
()
rank
=
mpu
.
get_
intra_laye
r_model_parallel_rank
()
rank
=
mpu
.
get_
tenso
r_model_parallel_rank
()
mpu
.
destroy_model_parallel
()
mpu
.
destroy_model_parallel
()
return
rank
,
hidden_size
,
intra_laye
r_model_parallel_size
,
loss
,
\
return
rank
,
hidden_size
,
tenso
r_model_parallel_size
,
loss
,
\
transformer_layer
,
identity_layer
transformer_layer
,
identity_layer
def
test_parallel_transformer_layer
(
intra_laye
r_model_parallel_size
):
def
test_parallel_transformer_layer
(
tenso
r_model_parallel_size
):
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing ParallelTransformerLayer with model parallel '
print
(
'> testing ParallelTransformerLayer with model parallel '
'size: {}'
.
format
(
intra_laye
r_model_parallel_size
))
'size: {}'
.
format
(
tenso
r_model_parallel_size
))
num_att_heads_per_partition
=
3
num_att_heads_per_partition
=
3
hidden_size_per_att_head
=
7
hidden_size_per_att_head
=
7
batch_size
=
5
batch_size
=
5
sequence_length
=
13
sequence_length
=
13
rank_1
,
hidden_size_1
,
intra_laye
r_model_parallel_size_1
,
loss_1
,
\
rank_1
,
hidden_size_1
,
tenso
r_model_parallel_size_1
,
loss_1
,
\
transformer_layer_1
,
identity_layer_1
=
parallel_transformer
(
transformer_layer_1
,
identity_layer_1
=
parallel_transformer
(
1
,
num_att_heads_per_partition
,
1
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
batch_size
,
sequence_length
)
hidden_size_per_att_head
,
batch_size
,
sequence_length
)
rank
,
hidden_size
,
intra_laye
r_model_parallel_size
,
loss
,
\
rank
,
hidden_size
,
tenso
r_model_parallel_size
,
loss
,
\
transformer_layer
,
identity_layer
=
parallel_transformer
(
transformer_layer
,
identity_layer
=
parallel_transformer
(
intra_laye
r_model_parallel_size
,
num_att_heads_per_partition
,
tenso
r_model_parallel_size
,
num_att_heads_per_partition
,
hidden_size_per_att_head
,
batch_size
,
sequence_length
)
hidden_size_per_att_head
,
batch_size
,
sequence_length
)
error
=
loss_1
.
sub
(
loss
).
abs
().
max
()
error
=
loss_1
.
sub
(
loss
).
abs
().
max
()
...
@@ -494,37 +494,37 @@ if __name__ == '__main__':
...
@@ -494,37 +494,37 @@ if __name__ == '__main__':
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
print_separator
(
'test initialize affine weight'
)
print_separator
(
'test initialize affine weight'
)
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
test_initialize_affine_weight
(
intra_laye
r_model_parallel_size
)
test_initialize_affine_weight
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
print_separator
(
'test parallel embedding'
)
print_separator
(
'test parallel embedding'
)
test_parallel_embedding
(
intra_laye
r_model_parallel_size
)
test_parallel_embedding
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
print_separator
(
'test column-parallel linear'
)
print_separator
(
'test column-parallel linear'
)
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
test_column_parallel_linear
(
intra_laye
r_model_parallel_size
)
test_column_parallel_linear
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
print_separator
(
'test row-parallel linear'
)
print_separator
(
'test row-parallel linear'
)
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
test_row_parallel_linear
(
intra_laye
r_model_parallel_size
)
test_row_parallel_linear
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
print_separator
(
'test parallel self-attention'
)
print_separator
(
'test parallel self-attention'
)
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
test_parallel_self_attention
(
intra_laye
r_model_parallel_size
)
test_parallel_self_attention
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
print_separator
(
'test parallel transformer'
)
print_separator
(
'test parallel transformer'
)
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
test_parallel_transformer_layer
(
intra_laye
r_model_parallel_size
)
test_parallel_transformer_layer
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
megatron/mpu/tests/test_random.py
View file @
52a5f2f2
...
@@ -21,14 +21,14 @@ import sys
...
@@ -21,14 +21,14 @@ import sys
sys
.
path
.
append
(
"../.."
)
sys
.
path
.
append
(
"../.."
)
def
test_set_cuda_rng_state
(
intra_laye
r_model_parallel_size
):
def
test_set_cuda_rng_state
(
tenso
r_model_parallel_size
):
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing set_rng_state with size {} ...'
.
print
(
'> testing set_rng_state with size {} ...'
.
format
(
intra_laye
r_model_parallel_size
))
format
(
tenso
r_model_parallel_size
))
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
size
=
123
size
=
123
seed
=
1234
seed
=
1234
...
@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(intra_layer_model_parallel_size):
...
@@ -83,14 +83,14 @@ def test_set_cuda_rng_state(intra_layer_model_parallel_size):
print
(
'>> passed the test :-)'
)
print
(
'>> passed the test :-)'
)
def
test_cuda_rng_tracker
(
intra_laye
r_model_parallel_size
):
def
test_cuda_rng_tracker
(
tenso
r_model_parallel_size
):
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing cuda rng tracker with size {} ...'
.
print
(
'> testing cuda rng tracker with size {} ...'
.
format
(
intra_laye
r_model_parallel_size
))
format
(
tenso
r_model_parallel_size
))
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
seed_1
=
1234
seed_1
=
1234
seed_2
=
4321
seed_2
=
4321
...
@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(intra_layer_model_parallel_size):
...
@@ -154,20 +154,20 @@ def test_cuda_rng_tracker(intra_layer_model_parallel_size):
print
(
'>> passed the test :-)'
)
print
(
'>> passed the test :-)'
)
def
test_
intra_layer_
model_parallel_cuda_manual_seed
(
intra_laye
r_model_parallel_size
):
def
test_model_parallel_cuda_manual_seed
(
tenso
r_model_parallel_size
):
if
torch
.
distributed
.
get_rank
()
==
0
:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
'> testing model parallel cuda manual seed with size {} ...'
.
print
(
'> testing model parallel cuda manual seed with size {} ...'
.
format
(
intra_laye
r_model_parallel_size
))
format
(
tenso
r_model_parallel_size
))
mpu
.
initialize_model_parallel
(
intra_laye
r_model_parallel_size
)
mpu
.
initialize_model_parallel
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
=
mpu
.
get_
intra_laye
r_model_parallel_world_size
()
tenso
r_model_parallel_size
=
mpu
.
get_
tenso
r_model_parallel_world_size
()
mpu
.
intra_layer_
model_parallel_cuda_manual_seed
(
12345
)
mpu
.
model_parallel_cuda_manual_seed
(
12345
)
assert
torch
.
cuda
.
initial_seed
()
==
12345
assert
torch
.
cuda
.
initial_seed
()
==
12345
with
mpu
.
get_cuda_rng_tracker
().
fork
():
with
mpu
.
get_cuda_rng_tracker
().
fork
():
assert
torch
.
cuda
.
initial_seed
()
==
(
12345
+
2718
+
assert
torch
.
cuda
.
initial_seed
()
==
(
12345
+
2718
+
mpu
.
get_
intra_laye
r_model_parallel_rank
())
mpu
.
get_
tenso
r_model_parallel_rank
())
# Reset the tracker
# Reset the tracker
mpu
.
get_cuda_rng_tracker
().
reset
()
mpu
.
get_cuda_rng_tracker
().
reset
()
...
@@ -185,20 +185,20 @@ if __name__ == '__main__':
...
@@ -185,20 +185,20 @@ if __name__ == '__main__':
initialize_distributed
()
initialize_distributed
()
world_size
=
torch
.
distributed
.
get_world_size
()
world_size
=
torch
.
distributed
.
get_world_size
()
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
print_separator
(
'test set rng state'
)
print_separator
(
'test set rng state'
)
test_set_cuda_rng_state
(
intra_laye
r_model_parallel_size
)
test_set_cuda_rng_state
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
print_separator
(
'test cuda rng tracker'
)
print_separator
(
'test cuda rng tracker'
)
test_cuda_rng_tracker
(
intra_laye
r_model_parallel_size
)
test_cuda_rng_tracker
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
intra_laye
r_model_parallel_size
=
1
tenso
r_model_parallel_size
=
1
while
intra_laye
r_model_parallel_size
<=
world_size
:
while
tenso
r_model_parallel_size
<=
world_size
:
print_separator
(
'test model parallel cuda manual seed'
)
print_separator
(
'test model parallel cuda manual seed'
)
test_
intra_layer_
model_parallel_cuda_manual_seed
(
intra_laye
r_model_parallel_size
)
test_model_parallel_cuda_manual_seed
(
tenso
r_model_parallel_size
)
intra_laye
r_model_parallel_size
*=
2
tenso
r_model_parallel_size
*=
2
megatron/text_generation_utils.py
View file @
52a5f2f2
...
@@ -88,7 +88,7 @@ def generate_samples_input_from_file(model):
...
@@ -88,7 +88,7 @@ def generate_samples_input_from_file(model):
# Read the sample file and open the output file.
# Read the sample file and open the output file.
assert
args
.
sample_input_file
is
not
None
,
\
assert
args
.
sample_input_file
is
not
None
,
\
'sample input file is not provided.'
'sample input file is not provided.'
if
mpu
.
get_
intra_laye
r_model_parallel_rank
()
==
0
:
if
mpu
.
get_
tenso
r_model_parallel_rank
()
==
0
:
fname
=
open
(
args
.
sample_input_file
,
"r"
)
fname
=
open
(
args
.
sample_input_file
,
"r"
)
all_raw_text
=
fname
.
readlines
()
all_raw_text
=
fname
.
readlines
()
input_count
=
len
(
all_raw_text
)
input_count
=
len
(
all_raw_text
)
...
@@ -105,10 +105,10 @@ def generate_samples_input_from_file(model):
...
@@ -105,10 +105,10 @@ def generate_samples_input_from_file(model):
model
.
eval
()
model
.
eval
()
with
torch
.
no_grad
():
with
torch
.
no_grad
():
while
True
:
while
True
:
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_
intra_laye
r_model_parallel_group
())
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_
tenso
r_model_parallel_group
())
terminate_runs
=
0
terminate_runs
=
0
if
mpu
.
get_
intra_laye
r_model_parallel_rank
()
==
0
:
if
mpu
.
get_
tenso
r_model_parallel_rank
()
==
0
:
raw_text
=
all_raw_text
[
input_pos
]
raw_text
=
all_raw_text
[
input_pos
]
input_pos
+=
1
input_pos
+=
1
if
input_pos
==
input_count
:
if
input_pos
==
input_count
:
...
@@ -131,8 +131,8 @@ def generate_samples_input_from_file(model):
...
@@ -131,8 +131,8 @@ def generate_samples_input_from_file(model):
terminate_runs_tensor
=
torch
.
cuda
.
LongTensor
([
terminate_runs
])
terminate_runs_tensor
=
torch
.
cuda
.
LongTensor
([
terminate_runs
])
torch
.
distributed
.
broadcast
(
terminate_runs_tensor
,
torch
.
distributed
.
broadcast
(
terminate_runs_tensor
,
mpu
.
get_
intra_laye
r_model_parallel_src_rank
(),
mpu
.
get_
tenso
r_model_parallel_src_rank
(),
group
=
mpu
.
get_
intra_laye
r_model_parallel_group
())
group
=
mpu
.
get_
tenso
r_model_parallel_group
())
terminate_runs
=
terminate_runs_tensor
[
0
].
item
()
terminate_runs
=
terminate_runs_tensor
[
0
].
item
()
if
terminate_runs
==
1
:
if
terminate_runs
==
1
:
...
@@ -143,7 +143,7 @@ def generate_samples_input_from_file(model):
...
@@ -143,7 +143,7 @@ def generate_samples_input_from_file(model):
decode_tokens
,
_
=
decode_tokens
decode_tokens
,
_
=
decode_tokens
decode_tokens
=
decode_tokens
[
0
].
cpu
().
numpy
().
tolist
()
decode_tokens
=
decode_tokens
[
0
].
cpu
().
numpy
().
tolist
()
if
mpu
.
get_
intra_laye
r_model_parallel_rank
()
==
0
:
if
mpu
.
get_
tenso
r_model_parallel_rank
()
==
0
:
os
.
system
(
'clear'
)
os
.
system
(
'clear'
)
print
(
"
\n
Context:"
,
raw_text
,
flush
=
True
)
print
(
"
\n
Context:"
,
raw_text
,
flush
=
True
)
trim_decode_tokens
=
tokenizer
.
detokenize
(
trim_decode_tokens
=
tokenizer
.
detokenize
(
...
@@ -158,7 +158,7 @@ def generate_samples_input_from_file(model):
...
@@ -158,7 +158,7 @@ def generate_samples_input_from_file(model):
raw_text
=
None
raw_text
=
None
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_
intra_laye
r_model_parallel_group
())
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_
tenso
r_model_parallel_group
())
context_count
+=
1
context_count
+=
1
...
@@ -171,10 +171,10 @@ def generate_samples_interactive(model, print_frequency=24):
...
@@ -171,10 +171,10 @@ def generate_samples_interactive(model, print_frequency=24):
model
.
eval
()
model
.
eval
()
with
torch
.
no_grad
():
with
torch
.
no_grad
():
while
True
:
while
True
:
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_
intra_laye
r_model_parallel_group
())
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_
tenso
r_model_parallel_group
())
terminate_runs
=
0
terminate_runs
=
0
if
mpu
.
get_
intra_laye
r_model_parallel_rank
()
==
0
:
if
mpu
.
get_
tenso
r_model_parallel_rank
()
==
0
:
os
.
system
(
'clear'
)
os
.
system
(
'clear'
)
raw_text
=
input
(
"
\n
Context prompt (stop to exit) >>> "
)
raw_text
=
input
(
"
\n
Context prompt (stop to exit) >>> "
)
while
not
raw_text
:
while
not
raw_text
:
...
@@ -198,8 +198,8 @@ def generate_samples_interactive(model, print_frequency=24):
...
@@ -198,8 +198,8 @@ def generate_samples_interactive(model, print_frequency=24):
terminate_runs_tensor
=
torch
.
cuda
.
LongTensor
([
terminate_runs
])
terminate_runs_tensor
=
torch
.
cuda
.
LongTensor
([
terminate_runs
])
torch
.
distributed
.
broadcast
(
terminate_runs_tensor
,
torch
.
distributed
.
broadcast
(
terminate_runs_tensor
,
mpu
.
get_
intra_laye
r_model_parallel_src_rank
(),
mpu
.
get_
tenso
r_model_parallel_src_rank
(),
group
=
mpu
.
get_
intra_laye
r_model_parallel_group
())
group
=
mpu
.
get_
tenso
r_model_parallel_group
())
terminate_runs
=
terminate_runs_tensor
[
0
].
item
()
terminate_runs
=
terminate_runs_tensor
[
0
].
item
()
if
terminate_runs
==
1
:
if
terminate_runs
==
1
:
...
@@ -210,7 +210,7 @@ def generate_samples_interactive(model, print_frequency=24):
...
@@ -210,7 +210,7 @@ def generate_samples_interactive(model, print_frequency=24):
decode_tokens
,
_
=
decode_tokens
decode_tokens
,
_
=
decode_tokens
decode_tokens
=
decode_tokens
[
0
].
cpu
().
numpy
().
tolist
()
decode_tokens
=
decode_tokens
[
0
].
cpu
().
numpy
().
tolist
()
if
mpu
.
get_
intra_laye
r_model_parallel_rank
()
==
0
and
\
if
mpu
.
get_
tenso
r_model_parallel_rank
()
==
0
and
\
counter
%
print_frequency
==
0
:
counter
%
print_frequency
==
0
:
os
.
system
(
'clear'
)
os
.
system
(
'clear'
)
print
(
"
\n
Context:"
,
raw_text
,
flush
=
True
)
print
(
"
\n
Context:"
,
raw_text
,
flush
=
True
)
...
@@ -218,7 +218,7 @@ def generate_samples_interactive(model, print_frequency=24):
...
@@ -218,7 +218,7 @@ def generate_samples_interactive(model, print_frequency=24):
decode_tokens
)[
len
(
raw_text
):]
decode_tokens
)[
len
(
raw_text
):]
print
(
"
\n
Megatron-LM:"
,
trim_decode_tokens
,
flush
=
True
)
print
(
"
\n
Megatron-LM:"
,
trim_decode_tokens
,
flush
=
True
)
if
mpu
.
get_
intra_laye
r_model_parallel_rank
()
==
0
:
if
mpu
.
get_
tenso
r_model_parallel_rank
()
==
0
:
os
.
system
(
'clear'
)
os
.
system
(
'clear'
)
print
(
"
\n
Context:"
,
raw_text
,
flush
=
True
)
print
(
"
\n
Context:"
,
raw_text
,
flush
=
True
)
trim_decode_tokens
=
tokenizer
.
detokenize
(
trim_decode_tokens
=
tokenizer
.
detokenize
(
...
@@ -226,10 +226,10 @@ def generate_samples_interactive(model, print_frequency=24):
...
@@ -226,10 +226,10 @@ def generate_samples_interactive(model, print_frequency=24):
print
(
"
\n
Megatron-LM:"
,
trim_decode_tokens
,
flush
=
True
)
print
(
"
\n
Megatron-LM:"
,
trim_decode_tokens
,
flush
=
True
)
raw_text
=
None
raw_text
=
None
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_
intra_laye
r_model_parallel_group
())
torch
.
distributed
.
barrier
(
group
=
mpu
.
get_
tenso
r_model_parallel_group
())
context_count
+=
1
context_count
+=
1
if
mpu
.
get_
intra_laye
r_model_parallel_rank
()
==
0
:
if
mpu
.
get_
tenso
r_model_parallel_rank
()
==
0
:
input
(
"
\n
Press any key to continue >>>"
)
input
(
"
\n
Press any key to continue >>>"
)
...
@@ -299,11 +299,11 @@ def get_token_stream(model, context_tokens):
...
@@ -299,11 +299,11 @@ def get_token_stream(model, context_tokens):
context_length_tensor
=
torch
.
cuda
.
LongTensor
(
context_lengths
)
context_length_tensor
=
torch
.
cuda
.
LongTensor
(
context_lengths
)
torch
.
distributed
.
broadcast
(
context_length_tensor
,
torch
.
distributed
.
broadcast
(
context_length_tensor
,
mpu
.
get_
intra_laye
r_model_parallel_src_rank
(),
mpu
.
get_
tenso
r_model_parallel_src_rank
(),
group
=
mpu
.
get_
intra_laye
r_model_parallel_group
())
group
=
mpu
.
get_
tenso
r_model_parallel_group
())
torch
.
distributed
.
broadcast
(
context_tokens_tensor
,
torch
.
distributed
.
broadcast
(
context_tokens_tensor
,
mpu
.
get_
intra_laye
r_model_parallel_src_rank
(),
mpu
.
get_
tenso
r_model_parallel_src_rank
(),
group
=
mpu
.
get_
intra_laye
r_model_parallel_group
())
group
=
mpu
.
get_
tenso
r_model_parallel_group
())
context_length
=
context_length_tensor
.
min
().
item
()
context_length
=
context_length_tensor
.
min
().
item
()
tokens
,
attention_mask
,
position_ids
=
get_batch
(
context_tokens_tensor
)
tokens
,
attention_mask
,
position_ids
=
get_batch
(
context_tokens_tensor
)
...
...
megatron/tokenizer/tokenizer.py
View file @
52a5f2f2
...
@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
...
@@ -56,7 +56,7 @@ def _vocab_size_with_padding(orig_vocab_size, args):
after
=
orig_vocab_size
after
=
orig_vocab_size
multiple
=
args
.
make_vocab_size_divisible_by
*
\
multiple
=
args
.
make_vocab_size_divisible_by
*
\
args
.
intra_laye
r_model_parallel_size
args
.
tenso
r_model_parallel_size
while
(
after
%
multiple
)
!=
0
:
while
(
after
%
multiple
)
!=
0
:
after
+=
1
after
+=
1
if
args
.
rank
==
0
:
if
args
.
rank
==
0
:
...
...
megatron/training.py
View file @
52a5f2f2
...
@@ -124,10 +124,10 @@ def get_model(model_provider_func):
...
@@ -124,10 +124,10 @@ def get_model(model_provider_func):
# Print number of parameters.
# Print number of parameters.
if
mpu
.
get_data_parallel_rank
()
==
0
:
if
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
' > number of parameters on (
intra-layer, inter-layer
) '
print
(
' > number of parameters on (
tensor, pipeline
) '
'model parallel rank ({}, {}): {}'
.
format
(
'model parallel rank ({}, {}): {}'
.
format
(
mpu
.
get_
intra_laye
r_model_parallel_rank
(),
mpu
.
get_
tenso
r_model_parallel_rank
(),
mpu
.
get_
inter_layer
_model_parallel_rank
(),
mpu
.
get_
pipeline
_model_parallel_rank
(),
sum
([
p
.
nelement
()
for
p
in
model
.
parameters
()])),
flush
=
True
)
sum
([
p
.
nelement
()
for
p
in
model
.
parameters
()])),
flush
=
True
)
# GPU allocation.
# GPU allocation.
...
@@ -166,8 +166,8 @@ def get_optimizer(model):
...
@@ -166,8 +166,8 @@ def get_optimizer(model):
# Add model parallel attribute if it is not set.
# Add model parallel attribute if it is not set.
for
param_group
in
param_groups
:
for
param_group
in
param_groups
:
for
param
in
param_group
[
'params'
]:
for
param
in
param_group
[
'params'
]:
if
not
hasattr
(
param
,
'
intra_laye
r_model_parallel'
):
if
not
hasattr
(
param
,
'
tenso
r_model_parallel'
):
param
.
intra_laye
r_model_parallel
=
False
param
.
tenso
r_model_parallel
=
False
# Use Adam.
# Use Adam.
optimizer
=
Adam
(
param_groups
,
lr
=
args
.
lr
,
weight_decay
=
args
.
weight_decay
,
optimizer
=
Adam
(
param_groups
,
lr
=
args
.
lr
,
weight_decay
=
args
.
weight_decay
,
...
@@ -260,7 +260,7 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
...
@@ -260,7 +260,7 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
tensor_recv_prev
=
tensor_recv_prev
,
tensor_recv_prev
=
tensor_recv_prev
,
tensor_send_next
=
tensor_send_next
,
tensor_send_next
=
tensor_send_next
,
tensor_recv_next
=
tensor_recv_next
,
tensor_recv_next
=
tensor_recv_next
,
group
=
mpu
.
get_
inter_layer
_model_parallel_group
())
group
=
mpu
.
get_
pipeline
_model_parallel_group
())
return
tensor_recv_prev
,
tensor_recv_next
return
tensor_recv_prev
,
tensor_recv_next
...
@@ -304,7 +304,7 @@ def train_step(forward_step_func, data_iterator,
...
@@ -304,7 +304,7 @@ def train_step(forward_step_func, data_iterator,
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
# Compute number of microbatches in a minibatch.
# Compute number of microbatches in a minibatch.
num_microbatches_to_pipeline
=
args
.
inter_layer
_model_parallel_size
\
num_microbatches_to_pipeline
=
args
.
pipeline
_model_parallel_size
\
if
args
.
use_pipelining
else
1
if
args
.
use_pipelining
else
1
input_tensors
=
[]
input_tensors
=
[]
...
@@ -313,7 +313,7 @@ def train_step(forward_step_func, data_iterator,
...
@@ -313,7 +313,7 @@ def train_step(forward_step_func, data_iterator,
# Run forward pass for all microbatches in minibatch.
# Run forward pass for all microbatches in minibatch.
for
i
in
range
(
num_microbatches_to_pipeline
):
for
i
in
range
(
num_microbatches_to_pipeline
):
if
not
mpu
.
is_
inter_layer
_first_stage
():
if
not
mpu
.
is_
pipeline
_first_stage
():
input_tensor
,
_
=
communicate
(
input_tensor
,
_
=
communicate
(
tensor_send_next
=
None
,
tensor_send_next
=
None
,
tensor_send_prev
=
None
,
tensor_send_prev
=
None
,
...
@@ -327,7 +327,7 @@ def train_step(forward_step_func, data_iterator,
...
@@ -327,7 +327,7 @@ def train_step(forward_step_func, data_iterator,
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
timers
(
'forward'
).
stop
()
timers
(
'forward'
).
stop
()
if
mpu
.
is_
inter_layer
_last_stage
():
if
mpu
.
is_
pipeline
_last_stage
():
loss
,
loss_reduced
=
output_tensor
loss
,
loss_reduced
=
output_tensor
output_tensor
=
loss
output_tensor
=
loss
losses_reduced
.
append
(
loss_reduced
)
losses_reduced
.
append
(
loss_reduced
)
...
@@ -346,7 +346,7 @@ def train_step(forward_step_func, data_iterator,
...
@@ -346,7 +346,7 @@ def train_step(forward_step_func, data_iterator,
input_tensor
=
input_tensors
.
pop
(
0
)
input_tensor
=
input_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
if
mpu
.
is_
inter_layer
_last_stage
():
if
mpu
.
is_
pipeline
_last_stage
():
output_grad_tensor
=
None
output_grad_tensor
=
None
else
:
else
:
_
,
output_grad_tensor
=
communicate
(
_
,
output_grad_tensor
=
communicate
(
...
@@ -362,7 +362,7 @@ def train_step(forward_step_func, data_iterator,
...
@@ -362,7 +362,7 @@ def train_step(forward_step_func, data_iterator,
backward_step
(
optimizer
,
model
,
input_tensor
,
output_tensor
,
output_grad_tensor
)
backward_step
(
optimizer
,
model
,
input_tensor
,
output_tensor
,
output_grad_tensor
)
timers
(
'backward'
).
stop
()
timers
(
'backward'
).
stop
()
if
not
mpu
.
is_
inter_layer
_first_stage
():
if
not
mpu
.
is_
pipeline
_first_stage
():
communicate
(
communicate
(
tensor_send_next
=
None
,
tensor_send_next
=
None
,
tensor_send_prev
=
input_grad_tensor
,
tensor_send_prev
=
input_grad_tensor
,
...
@@ -383,8 +383,8 @@ def train_step(forward_step_func, data_iterator,
...
@@ -383,8 +383,8 @@ def train_step(forward_step_func, data_iterator,
timers
(
'backward-master-grad'
).
stop
()
timers
(
'backward-master-grad'
).
stop
()
# All-reduce across first and last stages.
# All-reduce across first and last stages.
if
(
mpu
.
is_
inter_layer
_first_stage
()
or
mpu
.
is_
inter_layer
_last_stage
())
and
\
if
(
mpu
.
is_
pipeline
_first_stage
()
or
mpu
.
is_
pipeline
_last_stage
())
and
\
args
.
inter_layer
_model_parallel_size
>
1
:
args
.
pipeline
_model_parallel_size
>
1
:
unwrapped_model
=
model
unwrapped_model
=
model
while
isinstance
(
unwrapped_model
,
(
torchDDP
,
LocalDDP
,
FP16_Module
)):
while
isinstance
(
unwrapped_model
,
(
torchDDP
,
LocalDDP
,
FP16_Module
)):
unwrapped_model
=
unwrapped_model
.
module
unwrapped_model
=
unwrapped_model
.
module
...
@@ -421,7 +421,7 @@ def train_step(forward_step_func, data_iterator,
...
@@ -421,7 +421,7 @@ def train_step(forward_step_func, data_iterator,
else
:
else
:
skipped_iter
=
1
skipped_iter
=
1
if
mpu
.
is_
inter_layer
_last_stage
():
if
mpu
.
is_
pipeline
_last_stage
():
# Average loss across microbatches.
# Average loss across microbatches.
loss_reduced
=
{}
loss_reduced
=
{}
for
key
in
losses_reduced
[
0
]:
for
key
in
losses_reduced
[
0
]:
...
@@ -604,7 +604,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
...
@@ -604,7 +604,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
print_rank_0
(
'Evaluating iter {}/{}'
.
format
(
iteration
,
print_rank_0
(
'Evaluating iter {}/{}'
.
format
(
iteration
,
args
.
eval_iters
))
args
.
eval_iters
))
if
not
mpu
.
is_
inter_layer
_first_stage
():
if
not
mpu
.
is_
pipeline
_first_stage
():
input_tensor
,
_
=
communicate
(
input_tensor
,
_
=
communicate
(
tensor_send_next
=
None
,
tensor_send_next
=
None
,
tensor_send_prev
=
None
,
tensor_send_prev
=
None
,
...
@@ -616,7 +616,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
...
@@ -616,7 +616,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
# Forward evaluation.
# Forward evaluation.
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
if
mpu
.
is_
inter_layer
_last_stage
():
if
mpu
.
is_
pipeline
_last_stage
():
_
,
loss_dict
=
output_tensor
_
,
loss_dict
=
output_tensor
# Reduce across processes.
# Reduce across processes.
for
key
in
loss_dict
:
for
key
in
loss_dict
:
...
@@ -671,7 +671,7 @@ def build_train_valid_test_data_iterators(
...
@@ -671,7 +671,7 @@ def build_train_valid_test_data_iterators(
print_rank_0
(
'> building train, validation, and test datasets ...'
)
print_rank_0
(
'> building train, validation, and test datasets ...'
)
# Data loader only on rank 0 of each model parallel group.
# Data loader only on rank 0 of each model parallel group.
if
mpu
.
get_
intra_laye
r_model_parallel_rank
()
==
0
:
if
mpu
.
get_
tenso
r_model_parallel_rank
()
==
0
:
# Rank, size, and global batch size.
# Rank, size, and global batch size.
data_parallel_size
=
mpu
.
get_data_parallel_world_size
()
data_parallel_size
=
mpu
.
get_data_parallel_world_size
()
global_batch_size
=
args
.
batch_size
*
data_parallel_size
global_batch_size
=
args
.
batch_size
*
data_parallel_size
...
@@ -709,8 +709,8 @@ def build_train_valid_test_data_iterators(
...
@@ -709,8 +709,8 @@ def build_train_valid_test_data_iterators(
# Broadcast num tokens.
# Broadcast num tokens.
torch
.
distributed
.
broadcast
(
flags
,
torch
.
distributed
.
broadcast
(
flags
,
mpu
.
get_
intra_laye
r_model_parallel_src_rank
(),
mpu
.
get_
tenso
r_model_parallel_src_rank
(),
group
=
mpu
.
get_
intra_laye
r_model_parallel_group
())
group
=
mpu
.
get_
tenso
r_model_parallel_group
())
args
.
do_train
=
flags
[
0
].
item
()
args
.
do_train
=
flags
[
0
].
item
()
args
.
do_valid
=
flags
[
1
].
item
()
args
.
do_valid
=
flags
[
1
].
item
()
args
.
do_test
=
flags
[
2
].
item
()
args
.
do_test
=
flags
[
2
].
item
()
...
...
megatron/utils.py
View file @
52a5f2f2
...
@@ -58,7 +58,7 @@ def print_params_min_max_norm(optimizer, iteration):
...
@@ -58,7 +58,7 @@ def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
"""Print min, max, and norm of all parameters."""
index
=
0
index
=
0
rank
=
torch
.
distributed
.
get_rank
()
rank
=
torch
.
distributed
.
get_rank
()
string
=
'iteration, rank, index,
intra-laye
r-model-parallel, min, max, norm
\n
'
string
=
'iteration, rank, index,
tenso
r-model-parallel, min, max, norm
\n
'
optimizer_
=
optimizer
optimizer_
=
optimizer
if
isinstance
(
optimizer
,
FP16_Optimizer
):
if
isinstance
(
optimizer
,
FP16_Optimizer
):
optimizer_
=
optimizer
.
optimizer
optimizer_
=
optimizer
.
optimizer
...
@@ -69,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration):
...
@@ -69,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration):
max_
=
param
.
data
.
max
()
max_
=
param
.
data
.
max
()
norm
=
param
.
data
.
norm
()
norm
=
param
.
data
.
norm
()
string
+=
'{:7d}, {:4d}, {:4d}, {:2d}, '
.
format
(
string
+=
'{:7d}, {:4d}, {:4d}, {:2d}, '
.
format
(
iteration
,
rank
,
index
,
int
(
param
.
intra_laye
r_model_parallel
))
iteration
,
rank
,
index
,
int
(
param
.
tenso
r_model_parallel
))
string
+=
'{:.6E}, {:.6E}, {:.6E}
\n
'
.
format
(
min_
,
max_
,
norm
)
string
+=
'{:.6E}, {:.6E}, {:.6E}
\n
'
.
format
(
min_
,
max_
,
norm
)
print
(
string
,
flush
=
True
)
print
(
string
,
flush
=
True
)
...
...
pretrain_bert.py
View file @
52a5f2f2
...
@@ -34,12 +34,12 @@ def model_provider():
...
@@ -34,12 +34,12 @@ def model_provider():
print_rank_0
(
'building BERT model ...'
)
print_rank_0
(
'building BERT model ...'
)
args
=
get_args
()
args
=
get_args
()
if
args
.
inter_layer
_model_parallel_size
>
1
:
if
args
.
pipeline
_model_parallel_size
>
1
:
# Determine model based on position of stage in pipeline.
# Determine model based on position of stage in pipeline.
if
mpu
.
is_
inter_layer
_first_stage
():
if
mpu
.
is_
pipeline
_first_stage
():
model
=
BertModelFirstStage
(
model
=
BertModelFirstStage
(
num_tokentypes
=
2
)
num_tokentypes
=
2
)
elif
mpu
.
is_
inter_layer
_last_stage
():
elif
mpu
.
is_
pipeline
_last_stage
():
model
=
BertModelLastStage
(
model
=
BertModelLastStage
(
num_tokentypes
=
2
,
num_tokentypes
=
2
,
add_binary_head
=
True
,
add_binary_head
=
True
,
...
@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor):
...
@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor):
timers
(
'batch generator'
).
stop
()
timers
(
'batch generator'
).
stop
()
# Forward pass through the model.
# Forward pass through the model.
if
mpu
.
is_
inter_layer
_first_stage
():
if
mpu
.
is_
pipeline
_first_stage
():
assert
input_tensor
is
None
assert
input_tensor
is
None
if
mpu
.
is_
inter_layer
_last_stage
():
if
mpu
.
is_
pipeline
_last_stage
():
output_tensor
=
model
(
tokens
,
padding_mask
,
tokentype_ids
=
types
,
output_tensor
=
model
(
tokens
,
padding_mask
,
tokentype_ids
=
types
,
lm_labels
=
lm_labels
)
lm_labels
=
lm_labels
)
else
:
else
:
output_tensor
=
model
(
tokens
,
padding_mask
,
tokentype_ids
=
types
)
output_tensor
=
model
(
tokens
,
padding_mask
,
tokentype_ids
=
types
)
elif
mpu
.
is_
inter_layer
_last_stage
():
elif
mpu
.
is_
pipeline
_last_stage
():
assert
input_tensor
is
not
None
assert
input_tensor
is
not
None
output_tensor
=
model
(
input_tensor
,
padding_mask
,
lm_labels
=
lm_labels
)
output_tensor
=
model
(
input_tensor
,
padding_mask
,
lm_labels
=
lm_labels
)
else
:
else
:
assert
input_tensor
is
not
None
assert
input_tensor
is
not
None
output_tensor
=
model
(
input_tensor
,
padding_mask
)
output_tensor
=
model
(
input_tensor
,
padding_mask
)
if
mpu
.
is_
inter_layer
_last_stage
():
if
mpu
.
is_
pipeline
_last_stage
():
lm_loss_
,
sop_logits
=
output_tensor
lm_loss_
,
sop_logits
=
output_tensor
sop_loss
=
F
.
cross_entropy
(
sop_logits
.
view
(
-
1
,
2
).
float
(),
sop_loss
=
F
.
cross_entropy
(
sop_logits
.
view
(
-
1
,
2
).
float
(),
...
...
pretrain_gpt2.py
View file @
52a5f2f2
...
@@ -33,11 +33,11 @@ def model_provider():
...
@@ -33,11 +33,11 @@ def model_provider():
print_rank_0
(
'building GPT2 model ...'
)
print_rank_0
(
'building GPT2 model ...'
)
args
=
get_args
()
args
=
get_args
()
if
args
.
inter_layer
_model_parallel_size
>
1
:
if
args
.
pipeline
_model_parallel_size
>
1
:
# Determine model based on position of stage in pipeline.
# Determine model based on position of stage in pipeline.
if
mpu
.
is_
inter_layer
_first_stage
():
if
mpu
.
is_
pipeline
_first_stage
():
model
=
GPT2ModelFirstStage
(
num_tokentypes
=
0
)
model
=
GPT2ModelFirstStage
(
num_tokentypes
=
0
)
elif
mpu
.
is_
inter_layer
_last_stage
():
elif
mpu
.
is_
pipeline
_last_stage
():
model
=
GPT2ModelLastStage
(
model
=
GPT2ModelLastStage
(
num_tokentypes
=
0
,
parallel_output
=
True
)
num_tokentypes
=
0
,
parallel_output
=
True
)
else
:
else
:
...
@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor):
...
@@ -93,21 +93,21 @@ def forward_step(data_iterator, model, input_tensor):
timers
(
'batch generator'
).
stop
()
timers
(
'batch generator'
).
stop
()
# Forward pass through the model.
# Forward pass through the model.
if
mpu
.
is_
inter_layer
_first_stage
():
if
mpu
.
is_
pipeline
_first_stage
():
assert
input_tensor
is
None
assert
input_tensor
is
None
if
mpu
.
is_
inter_layer
_last_stage
():
if
mpu
.
is_
pipeline
_last_stage
():
output_tensor
=
model
(
tokens
,
position_ids
,
attention_mask
,
output_tensor
=
model
(
tokens
,
position_ids
,
attention_mask
,
labels
=
labels
)
labels
=
labels
)
else
:
else
:
output_tensor
=
model
(
tokens
,
position_ids
,
attention_mask
)
output_tensor
=
model
(
tokens
,
position_ids
,
attention_mask
)
elif
mpu
.
is_
inter_layer
_last_stage
():
elif
mpu
.
is_
pipeline
_last_stage
():
assert
input_tensor
is
not
None
assert
input_tensor
is
not
None
output_tensor
=
model
(
input_tensor
,
attention_mask
,
labels
=
labels
)
output_tensor
=
model
(
input_tensor
,
attention_mask
,
labels
=
labels
)
else
:
else
:
assert
input_tensor
is
not
None
assert
input_tensor
is
not
None
output_tensor
=
model
(
input_tensor
,
attention_mask
)
output_tensor
=
model
(
input_tensor
,
attention_mask
)
if
mpu
.
is_
inter_layer
_last_stage
():
if
mpu
.
is_
pipeline
_last_stage
():
losses
=
output_tensor
.
float
()
losses
=
output_tensor
.
float
()
loss_mask
=
loss_mask
.
view
(
-
1
).
float
()
loss_mask
=
loss_mask
.
view
(
-
1
).
float
()
loss
=
torch
.
sum
(
losses
.
view
(
-
1
)
*
loss_mask
)
/
loss_mask
.
sum
()
loss
=
torch
.
sum
(
losses
.
view
(
-
1
)
*
loss_mask
)
/
loss_mask
.
sum
()
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment