OpenDAS / Megatron-LM / Commits / b2f57fc4

Commit b2f57fc4, authored Feb 17, 2021 by Mostofa Patwary

    pulled latest megatron

Parents: a4b628ab, 76e3fca0

Showing 17 changed files with 1019 additions and 398 deletions (+1019 -398).
megatron/arguments.py             +37    -4
megatron/checkpointing.py         +23   -12
megatron/initialize.py             +2    -1
megatron/model/fused_softmax.py   +10    -5
megatron/model/module.py           +2    -2
megatron/model/transformer.py     +21    -1
megatron/mpu/__init__.py           +3    -0
megatron/mpu/initialize.py        +46    -3
megatron/optimizer/__init__.py    +14   -13
megatron/p2p_communication.py    +272    -0
megatron/schedules.py            +424    -0
megatron/training.py             +110  -350
megatron/utils.py                 +18    -1
pretrain_bert.py                  +14    -3
pretrain_gpt.py                   +13    -2
tools/generate_samples_gpt.py      +3    -1
tools/merge_mp_partitions.py       +7    -0
megatron/arguments.py

```diff
@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
     assert args.world_size % model_parallel_size == 0, 'world size is not' \
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
         'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
     args.data_parallel_size = args.world_size // model_parallel_size
@@ -116,6 +116,18 @@ def parse_args(extra_args_provider=None, defaults={},
         print('setting global batch size to {}'.format(
             args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
+            'number of layers is not divisible by number of layers per virtual ' \
+            'pipeline stage'
+        args.virtual_pipeline_model_parallel_size = \
+            (args.num_layers // args.pipeline_model_parallel_size) // \
+            args.num_layers_per_virtual_pipeline_stage
+        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
+            'global batch size is not divisible by pipeline parallel size when ' \
+            'using interleaved schedule'
+    else:
+        args.virtual_pipeline_model_parallel_size = None

     # Parameters dtype.
     args.params_dtype = torch.float
@@ -202,7 +214,23 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.checkpoint_activations, \
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'

+    # custom kernel constraints check
+    seq_len = args.seq_length
+    attn_batch_size = \
+        (args.num_attention_heads / args.tensor_model_parallel_size) * \
+        args.micro_batch_size
+    # constraints on sequence length and attn_batch_size to enable warp based
+    # optimization and upper triangular optimization (for causal mask)
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    if args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion:
+        print('WARNING: constraints for invoking optimized'
+              ' fused softmax kernel are not met. We default back to unfused'
+              ' kernel invocations.')
+
     # Load scaled_masked_softmax_fusion_kernels
     if args.masked_softmax_fusion:
         fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
@@ -478,9 +506,9 @@ def _add_checkpointing_args(parser):
                        help='Output directory to save checkpoints to.')
     group.add_argument('--save-interval', type=int, default=None,
                        help='Number of iterations between checkpoint saves.')
-    group.add_argument('--no-save-optim', action='store_true',
+    group.add_argument('--no-save-optim', action='store_true', default=None,
                        help='Do not save current optimizer.')
-    group.add_argument('--no-save-rng', action='store_true',
+    group.add_argument('--no-save-rng', action='store_true', default=None,
                        help='Do not save current rng state.')
     group.add_argument('--load', type=str, default=None,
                        help='Directory containing a model checkpoint.')
@@ -541,6 +569,8 @@ def _add_distributed_args(parser):
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')
+    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of layers per virtual pipeline stage')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
@@ -548,6 +578,9 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
+    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
+                       help='Use scatter/gather to optimize communication of tensors in pipeline',
+                       dest='scatter_gather_tensors_in_pipeline')
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,
```
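To make the new virtual-pipeline arithmetic concrete, here is a small standalone sketch of the formula added above; the concrete numbers are hypothetical and not part of the commit.

```python
# Illustrative check of the virtual-pipeline arithmetic added to parse_args()
# (hypothetical values, chosen only to exercise the formula).
num_layers = 24
pipeline_model_parallel_size = 4
num_layers_per_virtual_pipeline_stage = 2

assert num_layers % num_layers_per_virtual_pipeline_stage == 0
# Each pipeline stage owns 24 // 4 = 6 layers; splitting those 6 layers into
# chunks of 2 yields 3 model chunks per stage, i.e. a virtual size of 3.
virtual_pipeline_model_parallel_size = \
    (num_layers // pipeline_model_parallel_size) // \
    num_layers_per_virtual_pipeline_stage
print(virtual_pipeline_model_parallel_size)  # 3
```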
megatron/checkpointing.py

```diff
@@ -21,12 +21,12 @@ import sys
 import numpy as np

 import torch
-from torch.nn.parallel import DistributedDataParallel as torchDDP

 from megatron import (get_args,
                       mpu,
                       print_rank_0,
-                      update_num_microbatches)
+                      update_num_microbatches,
+                      utils)

 _CHECKPOINT_VERSION = None
@@ -111,8 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
@@ -124,7 +123,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 3.0
         state_dict['iteration'] = iteration
-        state_dict['model'] = model.state_dict_for_save_checkpoint()
+        if len(model) == 1:
+            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
+        else:
+            for i in range(len(model)):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

         # Optimizer stuff.
         if not args.no_save_optim:
@@ -238,8 +242,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     args = get_args()
     load_dir = getattr(args, load_arg)

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)
+
     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)
@@ -324,7 +328,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
             print_rank_0('could not find arguments in the checkpoint ...')

     # Model.
-    model.load_state_dict(state_dict['model'], strict=strict)
+    if len(model) == 1:
+        model[0].load_state_dict(state_dict['model'], strict=strict)
+    else:
+        for i in range(len(model)):
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

     # Fix up query/key/value matrix ordering if needed
     checkpoint_version = get_checkpoint_version()
@@ -352,12 +361,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
             np.random.set_state(state_dict['np_rng_state'])
             torch.set_rng_state(state_dict['torch_rng_state'])
             torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+            # Check for empty states array
+            if not state_dict['rng_tracker_states']:
+                raise KeyError
             mpu.get_cuda_rng_tracker().set_states(
                 state_dict['rng_tracker_states'])
         except KeyError:
-            print_rank_0('Unable to load optimizer from checkpoint {}. '
+            print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
-                         'attempting to load the optimizer state, '
+                         'attempting to load the rng state, '
                          'exiting ...'.format(checkpoint_name))
             sys.exit()
@@ -376,8 +388,7 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f
     args = get_args()

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     load_path = args.load if from_realm_chkpt else args.ict_load
```
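The save and load paths above now key the checkpoint by model chunk. As a rough sketch of the resulting layout (illustrative helper, not code from the commit):

```python
# Hypothetical sketch of the checkpoint keys implied by the change above:
# a single model chunk keeps the old 'model' key, while interleaved pipeline
# chunks are stored under 'model0', 'model1', ...
def model_state_keys(num_chunks):
    if num_chunks == 1:
        return ['model']
    return ['model%d' % i for i in range(num_chunks)]

print(model_state_keys(1))  # ['model']
print(model_state_keys(2))  # ['model0', 'model1']
```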
megatron/initialize.py

```diff
@@ -133,7 +133,8 @@ def _initialize_distributed():
             print('model parallel is already initialized')
         else:
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size)
+                                          args.pipeline_model_parallel_size,
+                                          args.virtual_pipeline_model_parallel_size)


 def _init_autoresume():
```
megatron/model/fused_softmax.py

```diff
@@ -113,18 +113,23 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert (
             self.scale is None or softmax_in_fp32
         ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
         # [b, np, sq, sk]
+        assert input.dim() == 4
         data_size = input.size()
         query_seq_len = data_size[-2]
         key_seq_len = data_size[-1]
-        assert input.dim() == 4
+        attn_batch_size = data_size[0] * data_size[1]
+
+        # constraints on various tensor dimensions to enable warp based
+        # optimization and upper triangular optimization (for causal mask)
+        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
+            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

         # invoke custom kernel
-        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
-                query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
+        if self.input_in_fp16 and mask is not None and \
+                custom_kernel_constraint and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0

             if self.attn_mask_type == AttnMaskType.causal:
```
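The constraint above can be restated as a standalone predicate over the [b, np, sq, sk] shape; the sketch below uses made-up shapes and is not part of the commit.

```python
# Standalone restatement of the fused-softmax kernel constraint above,
# with hypothetical shapes; [b, np, sq, sk] as in the module's comment.
def meets_custom_kernel_constraint(b, np_, sq, sk):
    attn_batch_size = b * np_
    return 16 < sk <= 2048 and sq % 4 == 0 and attn_batch_size % 4 == 0

print(meets_custom_kernel_constraint(b=4, np_=16, sq=1024, sk=1024))  # True
print(meets_custom_kernel_constraint(b=4, np_=16, sq=1024, sk=4096))  # False: sk > 2048
```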
megatron/model/module.py

```diff
@@ -50,9 +50,9 @@ class MegatronModule(torch.nn.Module):
     def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage():
+        if mpu.is_pipeline_first_stage(ignore_virtual=True):
             return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage():
+        if mpu.is_pipeline_last_stage(ignore_virtual=True):
             if not self.share_word_embeddings:
                 raise Exception('word_embeddings_weight() called for last '
                                 'stage, but share_word_embeddings is false')
```
megatron/model/transformer.py

```diff
@@ -552,7 +552,27 @@ class ParallelTransformer(MegatronModule):
                 layer_number,
                 layer_type=layer_type,
                 self_attn_mask_type=self_attn_mask_type)

-        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+        if args.virtual_pipeline_model_parallel_size is not None:
+            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
+                'num_layers_per_stage must be divisible by ' \
+                'virtual_pipeline_model_parallel_size'
+            # Number of layers in each model chunk is the number of layers in the stage,
+            # divided by the number of model chunks in a stage.
+            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0]  [2]  [4]  [6]
+            # Stage 1: [1]  [3]  [5]  [7]
+            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0, 1]  [4, 5]
+            # Stage 1: [2, 3]  [6, 7]
+            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
+                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
+                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
+        else:
+            # Each stage gets a contiguous set of layers.
+            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
```
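The offset formula above reproduces the layer-to-chunk assignment described in the comments. A small standalone check for the 8-layer, 2-stage, 2-virtual-stage example (not code from the commit):

```python
# Re-derivation of the layer-offset arithmetic for the commented example
# (8 layers, 2 pipeline stages, 2 virtual stages); illustrative only.
num_layers = 8
pipeline_size = 2
virtual_size = 2
layers_per_chunk = (num_layers // pipeline_size) // virtual_size  # 2 layers per model chunk

for stage in range(pipeline_size):
    chunks = []
    for virtual_rank in range(virtual_size):
        offset = virtual_rank * (num_layers // virtual_size) + stage * layers_per_chunk
        chunks.append(list(range(offset, offset + layers_per_chunk)))
    print('Stage %d:' % stage, chunks)
# Stage 0: [[0, 1], [4, 5]]
# Stage 1: [[2, 3], [6, 7]]
```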
megatron/mpu/__init__.py

```diff
@@ -38,6 +38,7 @@ from .initialize import get_pipeline_model_parallel_next_rank
 from .initialize import get_pipeline_model_parallel_prev_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
+from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
@@ -58,6 +59,8 @@ from .random import get_cuda_rng_tracker
 from .random import init_checkpointed_activations_memory_buffer
 from .random import model_parallel_cuda_manual_seed
 from .random import reset_checkpointed_activations_memory_buffer
+from .random import gather_split_1d_tensor
+from .random import split_tensor_into_1d_equal_chunks

 from .utils import divide
 from .utils import split_tensor_along_last_dim
```
megatron/mpu/initialize.py

```diff
@@ -32,6 +32,9 @@ _EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None

+_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+
 # These values enable us to change the mpu sizes on the fly.
 _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
@@ -48,7 +51,8 @@ def is_unitialized():
 def initialize_model_parallel(tensor_model_parallel_size_=1,
-                              pipeline_model_parallel_size_=1):
+                              pipeline_model_parallel_size_=1,
+                              virtual_pipeline_model_parallel_size_=None):
     """
     Initialize model data parallel groups.
@@ -91,6 +95,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
     num_data_parallel_groups = world_size // data_parallel_size

+    if virtual_pipeline_model_parallel_size_ is not None:
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
+
     rank = torch.distributed.get_rank()

     # Build the data-parallel groups.
@@ -258,17 +268,46 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


-def is_pipeline_first_stage():
+def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        if get_virtual_pipeline_model_parallel_world_size() is not None and \
+                get_virtual_pipeline_model_parallel_rank() != 0:
+            return False
     return get_pipeline_model_parallel_rank() == 0


-def is_pipeline_last_stage():
+def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        virtual_pipeline_model_parallel_world_size = \
+            get_virtual_pipeline_model_parallel_world_size()
+        if virtual_pipeline_model_parallel_world_size is not None and \
+                get_virtual_pipeline_model_parallel_rank() != (
+                    virtual_pipeline_model_parallel_world_size - 1):
+            return False
     return get_pipeline_model_parallel_rank() == (
         get_pipeline_model_parallel_world_size() - 1)


+def get_virtual_pipeline_model_parallel_rank():
+    """Return the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+
+
+def set_virtual_pipeline_model_parallel_rank(rank):
+    """Set the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
+
+
+def get_virtual_pipeline_model_parallel_world_size():
+    """Return the virtual pipeline-parallel world size."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+
+
 def get_tensor_model_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
@@ -276,11 +315,13 @@ def get_tensor_model_parallel_src_rank():
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size

+
 def get_pipeline_model_parallel_first_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
     return _PIPELINE_GLOBAL_RANKS[0]

+
 def get_pipeline_model_parallel_last_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
@@ -294,6 +335,7 @@ def get_pipeline_model_parallel_next_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]

+
 def get_pipeline_model_parallel_prev_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
@@ -301,6 +343,7 @@ def get_pipeline_model_parallel_prev_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]

+
 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
     return torch.distributed.get_world_size(group=get_data_parallel_group())
```
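The `ignore_virtual` flag added above distinguishes "first/last model chunk" from "physically first/last pipeline stage". A minimal standalone mock of that logic, with made-up ranks (not the real mpu module):

```python
# Standalone mock of the ignore_virtual behavior above; assumes pipeline
# rank 0 of 4 holding virtual chunk 1 of 3 (all values hypothetical).
pipeline_rank, pipeline_world_size = 0, 4
virtual_rank, virtual_world_size = 1, 3

def is_pipeline_first_stage(ignore_virtual=False):
    # With interleaving, only the first model chunk on pipeline rank 0 is "first".
    if not ignore_virtual and virtual_rank != 0:
        return False
    return pipeline_rank == 0

print(is_pipeline_first_stage())                     # False: chunk 1 is not the first model chunk
print(is_pipeline_first_stage(ignore_virtual=True))  # True: physically the first pipeline stage
```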
megatron/optimizer/__init__.py

```diff
@@ -23,7 +23,7 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer


-def _get_params_for_weight_decay_optimization(module):
+def _get_params_for_weight_decay_optimization(modules):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and baises will have no weight decay but the rest will.
     """
@@ -32,18 +32,19 @@ def _get_params_for_weight_decay_optimization(module):
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module_ in module.modules():
-        if isinstance(module_, LayerNorm):
-            no_weight_decay_params['params'].extend(
-                [p for p in list(module_._parameters.values())
-                 if p is not None])
-        else:
-            weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n != 'bias'])
-            no_weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n == 'bias'])
+    for module in modules:
+        for module_ in module.modules():
+            if isinstance(module_, LayerNorm):
+                no_weight_decay_params['params'].extend(
+                    [p for p in list(module_._parameters.values())
+                     if p is not None])
+            else:
+                weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n != 'bias'])
+                no_weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n == 'bias'])

     return weight_decay_params, no_weight_decay_params
```
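After this change the helper takes a list of model chunks rather than a single module. A toy sketch of the new calling convention, covering only the bias/weight split and skipping the LayerNorm branch (toy modules, not Megatron models):

```python
# Toy illustration of looping over a list of model chunks as in the new
# _get_params_for_weight_decay_optimization signature (hypothetical modules).
import torch

chunks = [torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)]
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module in chunks:                 # new outer loop over model chunks
    for module_ in module.modules():
        for n, p in module_._parameters.items():
            if p is None:
                continue
            target = no_weight_decay_params if n == 'bias' else weight_decay_params
            target['params'].append(p)

print(len(weight_decay_params['params']),
      len(no_weight_decay_params['params']))  # 2 2: one weight and one bias per chunk
```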
megatron/p2p_communication.py (new file, +272)

```python
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
import operator
import torch

from megatron import get_args
from megatron import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 use_ring_exchange=False):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.

    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            mpu.get_tensor_model_parallel_world_size()
    else:
        tensor_chunk_shape = tensor_shape
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if use_ring_exchange:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()

    return tensor_recv_prev, tensor_recv_next


def recv_forward(timers=None, use_ring_exchange=False):
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor


def recv_backward(timers=None, use_ring_exchange=False):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, timers=None, use_ring_exchange=False):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
        _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('forward-send').stop()


def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send').start()
        _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('forward-send-backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False,
        use_ring_exchange=True)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next,
        use_ring_exchange=True)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev, recv_next, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next,
        use_ring_exchange=True)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
```
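The scatter/gather optimization in `_communicate` sends only 1/tensor_model_parallel_size of each activation and gathers the full tensor on the receiving side. A back-of-the-envelope check with hypothetical sizes (not values from the commit):

```python
# Chunk-size arithmetic from _communicate above, with made-up [s, b, h] sizes.
from functools import reduce
import operator

seq_length, micro_batch_size, hidden_size = 1024, 8, 1536
tensor_model_parallel_world_size = 4

tensor_shape = (seq_length, micro_batch_size, hidden_size)
full_numel = reduce(operator.mul, tensor_shape, 1)
chunk_numel = full_numel // tensor_model_parallel_world_size

print(full_numel)   # 12582912 elements in the full activation tensor
print(chunk_numel)  # 3145728 elements actually sent per pipeline hop
```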
megatron/schedules.py (new file, +424)

```python
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication


def forward_step(forward_step_func, data_iterator, model, input_tensor,
                 losses_reduced):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor."""
    timers = get_timers()

    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

    return output_tensor


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""
    args = get_args()

    timers = get_timers()
    timers('backward-compute').start()

    # Retain the grad on the input_tensor.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad

    timers('backward-compute').stop()

    return input_tensor_grad


def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    Returns dictionary with losses."""
    assert len(model) == 1
    model = model[0]

    losses_reduced = []
    for i in range(get_num_microbatches()):
        input_tensor, output_tensor_grad = None, None
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if not forward_only:
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

    return losses_reduced


def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
                                                  model, optimizer, timers,
                                                  forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer,
                          input_tensor,
                          output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(timers, use_ring_exchange=True))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev, timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor, input_tensor_grad,
                recv_prev=recv_prev, recv_next=recv_next,
                timers=timers)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(timers, use_ring_exchange=True))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next, timers))

    return losses_reduced


def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches,
                                  num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)

        # Barrier before first receive to measure forward stall.
        if i == (num_warmup_microbatches - 1):
            timers('forward-pipeline-stall').start()
            torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
            timers('forward-pipeline-stall').stop()

        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Barrier before first receive to measure forward stall.
    if num_warmup_microbatches == 0:
        timers('forward-pipeline-stall').start()
        torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
        timers('forward-pipeline-stall').stop()

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced
```
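The interleaved schedule's warmup count depends on the pipeline rank, as the comments above explain. A worked example with hypothetical sizes (4 pipeline stages, 2 model chunks, 8 microbatches per chunk; not values from the commit):

```python
# Warmup-microbatch computation from forward_backward_pipelining_with_interleaving,
# evaluated for every pipeline rank with made-up sizes.
pipeline_parallel_size = 4
num_model_chunks = 2
num_microbatches_per_chunk = 8          # stand-in for get_num_microbatches()
num_microbatches = num_microbatches_per_chunk * num_model_chunks

for pipeline_parallel_rank in range(pipeline_parallel_size):
    warmup = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
    warmup += (num_model_chunks - 1) * pipeline_parallel_size
    warmup = min(warmup, num_microbatches)
    print(pipeline_parallel_rank, warmup, num_microbatches - warmup)
# rank 0: 10 warmup / 6 steady-state; rank 3: 4 warmup / 12 steady-state --
# earlier stages run more warmup forwards before entering 1F1B.
```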
megatron/training.py
View file @
b2f57fc4
...
@@ -46,8 +46,12 @@ from megatron.learning_rates import AnnealingLR
...
@@ -46,8 +46,12 @@ from megatron.learning_rates import AnnealingLR
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.model.realm_model
import
ICTBertModel
from
megatron.model.realm_model
import
ICTBertModel
from
megatron.utils
import
check_adlr_autoresume_termination
from
megatron.utils
import
check_adlr_autoresume_termination
from
megatron.utils
import
unwrap_model
from
megatron.data.data_samplers
import
build_pretraining_data_loader
from
megatron.data.data_samplers
import
build_pretraining_data_loader
from
megatron.utils
import
calc_params_l2_norm
from
megatron.utils
import
calc_params_l2_norm
from
megatron.schedules
import
forward_backward_no_pipelining
from
megatron.schedules
import
forward_backward_pipelining_without_interleaving
from
megatron.schedules
import
forward_backward_pipelining_with_interleaving
from
megatron.utils
import
report_memory
from
megatron.utils
import
report_memory
...
@@ -107,23 +111,32 @@ def pretrain(train_valid_test_dataset_provider,
...
@@ -107,23 +111,32 @@ def pretrain(train_valid_test_dataset_provider,
timers
=
get_timers
()
timers
=
get_timers
()
# Model, optimizer, and learning rate.
# Model, optimizer, and learning rate.
timers
(
'model
and
optimizer'
).
start
()
timers
(
'model
-
and
-
optimizer
-setup
'
).
start
()
model
,
optimizer
,
lr_scheduler
=
setup_model_and_optimizer
(
model_provider
)
model
,
optimizer
,
lr_scheduler
=
setup_model_and_optimizer
(
model_provider
)
timers
(
'model
and
optimizer'
).
stop
()
timers
(
'model
-
and
-
optimizer
-setup
'
).
stop
()
print_datetime
(
'after model, optimizer, and learning rate '
print_datetime
(
'after model, optimizer, and learning rate '
'scheduler are built'
)
'scheduler are built'
)
# Data stuff.
# Data stuff.
timers
(
'train/valid/test data iterators'
).
start
()
timers
(
'train/valid/test-data-iterators-setup'
).
start
()
train_data_iterator
,
valid_data_iterator
,
test_data_iterator
\
if
args
.
virtual_pipeline_model_parallel_size
is
not
None
:
=
build_train_valid_test_data_iterators
(
all_data_iterators
=
[
train_valid_test_dataset_provider
)
build_train_valid_test_data_iterators
(
train_valid_test_dataset_provider
)
timers
(
'train/valid/test data iterators'
).
stop
()
for
_
in
range
(
len
(
model
))
]
train_data_iterator
=
[
data_iterators
[
0
]
for
data_iterators
in
all_data_iterators
]
valid_data_iterator
=
[
data_iterators
[
1
]
for
data_iterators
in
all_data_iterators
]
test_data_iterator
=
[
data_iterators
[
2
]
for
data_iterators
in
all_data_iterators
]
else
:
train_data_iterator
,
valid_data_iterator
,
test_data_iterator
\
=
build_train_valid_test_data_iterators
(
train_valid_test_dataset_provider
)
timers
(
'train/valid/test-data-iterators-setup'
).
stop
()
print_datetime
(
'after dataloaders are built'
)
print_datetime
(
'after dataloaders are built'
)
# Print setup timing.
# Print setup timing.
print_rank_0
(
'done with setup
s
...'
)
print_rank_0
(
'done with setup ...'
)
timers
.
log
([
'model
and
optimizer'
,
'train/valid/test
data
iterators'
])
timers
.
log
([
'model
-
and
-
optimizer
-setup
'
,
'train/valid/test
-
data
-
iterators
-setup
'
])
print_rank_0
(
'training ...'
)
print_rank_0
(
'training ...'
)
iteration
=
0
iteration
=
0
...
@@ -185,13 +198,16 @@ def get_model(model_provider_func):

     # Build model on cpu.
     model = model_provider_func()
+    if not isinstance(model, list):
+        model = [model]

     # Set tensor model parallel attributes if not set.
     # Only parameters that are already tensor model parallel have these
     # attributes set for them. We should make sure the default attributes
     # are set for all params so the optimizer can use them.
-    for param in model.parameters():
-        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+    for model_module in model:
+        for param in model_module.parameters():
+            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
...
@@ -199,22 +215,25 @@ def get_model(model_provider_func):
         'model parallel rank ({}, {}): {}'.format(
             mpu.get_tensor_model_parallel_rank(),
             mpu.get_pipeline_model_parallel_rank(),
-            sum([p.nelement() for p in model.parameters()])), flush=True)
+            sum([sum([p.nelement() for p in model_module.parameters()])
+                 for model_module in model])), flush=True)

     # GPU allocation.
-    model.cuda(torch.cuda.current_device())
+    for model_module in model:
+        model_module.cuda(torch.cuda.current_device())

     # Fp16 conversion.
     if args.fp16:
-        model = FP16Module(model)
+        model = [FP16Module(model_module) for model_module in model]

     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
-        model = torchDDP(model, device_ids=[i], output_device=i,
-                         process_group=mpu.get_data_parallel_group())
+        model = [torchDDP(model_module, device_ids=[i], output_device=i,
+                          process_group=mpu.get_data_parallel_group())
+                 for model_module in model]
         return model

     if args.DDP_impl == 'local':
-        model = LocalDDP(model)
+        model = [LocalDDP(model_module) for model_module in model]
         return model

     raise NotImplementedError('Unknown DDP implementation specified: {}. '
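For illustration only (not part of this commit): with `model` now always a list of pipeline-stage chunks (usually of length one), per-rank bookkeeping such as the parameter count above sums over chunks. A minimal sketch with made-up layer sizes:

import torch

# Two stand-in chunks for one pipeline stage (sizes are illustrative only).
chunks = [torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)]
total = sum(sum(p.nelement() for p in chunk.parameters()) for chunk in chunks)
print(total)  # each Linear(8, 8) has 8*8 + 8 = 72 params, so this prints 144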
...
@@ -270,9 +289,8 @@ def setup_model_and_optimizer(model_provider_func):

     model = get_model(model_provider_func)

-    unwrapped_model = model
-    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-        unwrapped_model = unwrapped_model.module
+    unwrapped_model = unwrap_model(model,
+                                   (torchDDP, LocalDDP, FP16Module))
     optimizer = get_megatron_optimizer(unwrapped_model)
     lr_scheduler = get_learning_rate_scheduler(optimizer)
...
@@ -282,305 +300,35 @@ def setup_model_and_optimizer(model_provider_func):
         # Extra barrier is added to make sure all ranks report the
         # max time.
         torch.distributed.barrier()
-        timers('load checkpoint').start()
+        timers('load-checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
         torch.distributed.barrier()
-        timers('load checkpoint').stop()
-        timers.log(['load checkpoint'])
+        timers('load-checkpoint').stop()
+        timers.log(['load-checkpoint'])
     else:
         args.iteration = 0

     # We only support local DDP with multiple micro-batches.
     if get_num_microbatches() > 1:
         assert args.DDP_impl == 'local'
+    if len(model) == 1:
+        assert args.DDP_impl == 'local'
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        assert args.DDP_impl == 'local'

     # get model without FP16 and/or TorchDDP wrappers
-    unwrapped_model = model
-    while hasattr(unwrapped_model, 'module'):
-        unwrapped_model = unwrapped_model.module
-
-    if args.iteration == 0 and hasattr(unwrapped_model,
-                                       'init_state_dict_from_bert'):
-        print("Initializing ICT from pretrained BERT model", flush=True)
-        unwrapped_model.init_state_dict_from_bert()
-        if args.fp16:
-            optimizer.reload_model_params()
+    model = unwrap_model(model)
+    for module in model:
+        if args.iteration == 0 and hasattr(module, 'init_state_dict_from_bert'):
+            print_rank_0("Initializing ICT from pretrained BERT model")
+            module.init_state_dict_from_bert()
+            if args.fp16:
+                optimizer.reload_model_params()

     return model, optimizer, lr_scheduler

-def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
-    """Communicate tensors between stages."""
-    args = get_args()
-
-    # Create placeholder tensors for receive in forward and backward directions
-    # if needed.
-    tensor_recv_prev = None
-    tensor_recv_next = None
-    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-    dtype = args.params_dtype
-    if args.fp32_residual_connection:
-        dtype = torch.float
-    if recv_forward:
-        tensor_recv_prev = torch.empty(tensor_shape,
-                                       requires_grad=True,
-                                       device=torch.cuda.current_device(),
-                                       dtype=dtype)
-    if recv_backward:
-        tensor_recv_next = torch.empty(tensor_shape,
-                                       requires_grad=True,
-                                       device=torch.cuda.current_device(),
-                                       dtype=dtype)
-
-    # Send tensors in both the forward and backward directions as appropriate.
-    ops = []
-    if tensor_send_prev is not None:
-        send_prev_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_prev,
-            mpu.get_pipeline_model_parallel_prev_rank())
-        ops.append(send_prev_op)
-    if tensor_recv_prev is not None:
-        recv_prev_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_prev,
-            mpu.get_pipeline_model_parallel_prev_rank())
-        ops.append(recv_prev_op)
-    if tensor_send_next is not None:
-        send_next_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_next,
-            mpu.get_pipeline_model_parallel_next_rank())
-        ops.append(send_next_op)
-    if tensor_recv_next is not None:
-        recv_next_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_next,
-            mpu.get_pipeline_model_parallel_next_rank())
-        ops.append(recv_next_op)
-    reqs = torch.distributed.batch_isend_irecv(ops)
-    for req in reqs:
-        req.wait()
-
-    return tensor_recv_prev, tensor_recv_next
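For illustration only (not part of this commit): the removed `communicate()` helper batches point-to-point sends and receives with `torch.distributed.P2POp` and `batch_isend_irecv`. A minimal standalone sketch of that pattern, assuming a two-process job launched with `torchrun --nproc_per_node=2` and the gloo backend:

import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()
peer = (rank + 1) % dist.get_world_size()

send_buf = torch.full((4,), float(rank))
recv_buf = torch.empty(4)

# Post one send and one matching recv, then wait on both requests.
ops = [dist.P2POp(dist.isend, send_buf, peer),
       dist.P2POp(dist.irecv, recv_buf, peer)]
for req in dist.batch_isend_irecv(ops):
    req.wait()
print(f"rank {rank} got {recv_buf.tolist()} from rank {peer}")
dist.destroy_process_group()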
-
-
-def backward_step(optimizer, model, input_tensor, output_tensor,
-                  output_tensor_grad):
-    """Backward step."""
-    args = get_args()
-    timers = get_timers()
-
-    # Retain the grad on the input_tensor.
-    if input_tensor is not None:
-        input_tensor.retain_grad()
-
-    # Backward pass.
-    if output_tensor_grad is None:
-        output_tensor = optimizer.scale_loss(output_tensor)
-    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
-
-    # Collect the grad of the input_tensor.
-    input_tensor_grad = None
-    if input_tensor is not None:
-        input_tensor_grad = input_tensor.grad
-
-    return input_tensor_grad
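For illustration only (not part of this commit): `retain_grad()` is what makes `input_tensor.grad` available on the non-leaf activation received from the previous stage. A tiny sketch of that mechanism with toy tensors:

import torch

# A non-leaf stand-in for the activation received from the previous stage.
input_tensor = torch.randn(3, requires_grad=True) * 1.0
input_tensor.retain_grad()                 # keep .grad on this non-leaf tensor

output_tensor = (input_tensor ** 2).sum()  # stand-in for the stage's output
torch.autograd.backward(output_tensor)

print(input_tensor.grad)                   # the gradient that would be sent backward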
-
-
-def forward_step_with_communication(forward_step_func, data_iterator, model,
-                                    input_tensors, output_tensors,
-                                    losses_reduced, timers):
-    args = get_args()
-
-    if not mpu.is_pipeline_first_stage():
-        timers('forward-recv').start()
-        input_tensor, _ = communicate(tensor_send_next=None,
-                                      tensor_send_prev=None,
-                                      recv_forward=True,
-                                      recv_backward=False)
-        timers('forward-recv').stop()
-    else:
-        input_tensor = None
-
-    # Forward model for one step.
-    timers('forward-compute').start()
-    output_tensor = forward_step_func(data_iterator, model, input_tensor)
-    timers('forward-compute').stop()
-
-    if mpu.is_pipeline_last_stage():
-        loss, loss_reduced = output_tensor
-        output_tensor = loss / get_num_microbatches()
-        losses_reduced.append(loss_reduced)
-    else:
-        timers('forward-send').start()
-        communicate(tensor_send_next=output_tensor,
-                    tensor_send_prev=None,
-                    recv_forward=False,
-                    recv_backward=False)
-        timers('forward-send').stop()
-
-    input_tensors.append(input_tensor)
-    output_tensors.append(output_tensor)
-
-
-def backward_step_with_communication(optimizer, model, input_tensors,
-                                     output_tensors, timers):
-    input_tensor = input_tensors.pop(0)
-    output_tensor = output_tensors.pop(0)
-
-    if mpu.is_pipeline_last_stage():
-        output_tensor_grad = None
-    else:
-        timers('backward-recv').start()
-        _, output_tensor_grad = communicate(tensor_send_next=None,
-                                            tensor_send_prev=None,
-                                            recv_forward=False,
-                                            recv_backward=True)
-        timers('backward-recv').stop()
-
-    # Backward pass for one step.
-    timers('backward-compute').start()
-    input_grad_tensor = \
-        backward_step(optimizer, model, input_tensor, output_tensor,
-                      output_tensor_grad)
-    timers('backward-compute').stop()
-
-    if not mpu.is_pipeline_first_stage():
-        timers('backward-send').start()
-        communicate(tensor_send_next=None,
-                    tensor_send_prev=input_grad_tensor,
-                    recv_forward=False,
-                    recv_backward=False)
-        timers('backward-send').stop()
-
-
-def forward_and_backward_steps_with_communication(forward_step_func,
-                                                  data_iterator, model,
-                                                  optimizer, input_tensor,
-                                                  last_microbatch,
-                                                  input_tensors,
-                                                  output_tensors,
-                                                  losses_reduced, timers):
-    args = get_args()
-
-    # Forward model for one step.
-    timers('forward-compute').start()
-    output_tensor = forward_step_func(data_iterator, model, input_tensor)
-    timers('forward-compute').stop()
-
-    if mpu.is_pipeline_last_stage():
-        loss, loss_reduced = output_tensor
-        output_tensor = loss / get_num_microbatches()
-        output_tensor_grad = None
-        losses_reduced.append(loss_reduced)
-    else:
-        timers('forward-send-backward-recv').start()
-        _, output_tensor_grad = communicate(tensor_send_next=output_tensor,
-                                            tensor_send_prev=None,
-                                            recv_forward=False,
-                                            recv_backward=True)
-        timers('forward-send-backward-recv').stop()
-
-    input_tensors.append(input_tensor)
-    output_tensors.append(output_tensor)
-
-    input_tensor = input_tensors.pop(0)
-    output_tensor = output_tensors.pop(0)
-
-    # Backward pass for one step.
-    timers('backward-compute').start()
-    input_grad_tensor = \
-        backward_step(optimizer, model, input_tensor, output_tensor,
-                      output_tensor_grad)
-    timers('backward-compute').stop()
-
-    if not mpu.is_pipeline_first_stage():
-        timers('backward-send-forward-recv').start()
-        input_tensor, _ = communicate(tensor_send_next=None,
-                                      tensor_send_prev=input_grad_tensor,
-                                      recv_forward=(not last_microbatch),
-                                      recv_backward=False)
-        timers('backward-send-forward-recv').stop()
-    else:
-        input_tensor = None
-
-    return input_tensor
-
-
-def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
-                                   optimizer, timers):
-    """Run forward and backward passes without inter-stage communication."""
-    args = get_args()
-
-    losses_reduced = []
-    for i in range(get_num_microbatches()):
-        timers('forward-compute').start()
-        loss, loss_reduced = forward_step_func(data_iterator, model,
-                                               input_tensor=None)
-        output_tensor = loss / get_num_microbatches()
-        losses_reduced.append(loss_reduced)
-        timers('forward-compute').stop()
-
-        timers('backward-compute').start()
-        output_tensor_grad = None
-        backward_step(optimizer, model, input_tensor=None,
-                      output_tensor=output_tensor, output_tensor_grad=None)
-        timers('backward-compute').stop()
-
-    return losses_reduced
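For illustration only (not part of this commit): dividing each microbatch loss by the number of microbatches, as done above, makes the accumulated gradients match a single full-batch step. A small self-contained check with illustrative shapes and data:

import torch

torch.manual_seed(0)
x, y = torch.randn(8, 4), torch.randn(8, 1)
w = torch.zeros(4, 1, requires_grad=True)

# One full-batch backward pass.
((x @ w - y) ** 2).mean().backward()
full_grad = w.grad.clone()

# Two microbatches, each loss scaled by 1 / num_microbatches, grads accumulated.
w.grad = None
num_microbatches = 2
for xb, yb in zip(x.chunk(num_microbatches), y.chunk(num_microbatches)):
    (((xb @ w - yb) ** 2).mean() / num_microbatches).backward()

print(torch.allclose(full_grad, w.grad))  # True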
-
-
-def forward_backward_pipelining(forward_step_func, data_iterator, model,
-                                optimizer, timers):
-    """Run 1F1B schedule, with communication and warmup + cooldown
-    microbatches as needed."""
-    args = get_args()
-
-    # Compute number of warmup microbatches.
-    num_microbatches = get_num_microbatches()
-    num_warmup_microbatches = \
-        (mpu.get_pipeline_model_parallel_world_size() -
-         mpu.get_pipeline_model_parallel_rank() - 1)
-    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
-    num_microbatches_remaining = num_microbatches - num_warmup_microbatches
-
-    input_tensors = []
-    output_tensors = []
-    losses_reduced = []
-
-    # Run warmup forward passes.
-    for i in range(num_warmup_microbatches):
-        forward_step_with_communication(
-            forward_step_func, data_iterator, model,
-            input_tensors, output_tensors, losses_reduced, timers)
-
-    # Before running 1F1B, need to receive first forward tensor.
-    # If all microbatches are run in warmup / cooldown phase, then no need to
-    # receive this tensor here.
-    if num_microbatches_remaining > 0:
-        if mpu.is_pipeline_first_stage():
-            input_tensor = None
-        else:
-            timers('forward-recv').start()
-            input_tensor, _ = communicate(tensor_send_next=None,
-                                          tensor_send_prev=None,
-                                          recv_forward=True,
-                                          recv_backward=False)
-            timers('forward-recv').stop()
-
-    # Run 1F1B.
-    for i in range(num_microbatches_remaining):
-        last_iteration = (i == (num_microbatches_remaining - 1))
-        input_tensor = \
-            forward_and_backward_steps_with_communication(
-                forward_step_func, data_iterator, model, optimizer,
-                input_tensor, last_iteration,
-                input_tensors, output_tensors, losses_reduced, timers)
-
-    # Run cooldown backward passes.
-    for i in range(num_warmup_microbatches):
-        backward_step_with_communication(
-            optimizer, model, input_tensors, output_tensors, timers)
-
-    return losses_reduced
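For illustration only (not part of this commit): a worked example of the warmup arithmetic above. With 4 pipeline stages and 8 microbatches, stages 0..3 run 3, 2, 1 and 0 warmup microbatches, the remainder in the 1F1B steady state, and the same counts again as cooldown backward passes:

pipeline_world_size = 4
num_microbatches = 8
for rank in range(pipeline_world_size):
    warmup = min(pipeline_world_size - rank - 1, num_microbatches)
    steady = num_microbatches - warmup
    print(f"stage {rank}: {warmup} warmup, {steady} 1F1B, {warmup} cooldown")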
 def train_step(forward_step_func, data_iterator,
                model, optimizer, lr_scheduler):
     """Single training step."""
...
@@ -591,29 +339,43 @@ def train_step(forward_step_func, data_iterator,
     optimizer.zero_grad()

     if mpu.get_pipeline_model_parallel_world_size() > 1:
-        losses_reduced = forward_backward_pipelining(
-            forward_step_func, data_iterator, model, optimizer, timers)
+        if args.virtual_pipeline_model_parallel_size is not None:
+            forward_backward_func = forward_backward_pipelining_with_interleaving
+        else:
+            forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
-        losses_reduced = forward_backward_no_pipelining(
-            forward_step_func, data_iterator, model, optimizer, timers)
+        forward_backward_func = forward_backward_no_pipelining
+    losses_reduced = forward_backward_func(
+        forward_step_func, data_iterator, model, optimizer, timers,
+        forward_only=False)

     # All-reduce if needed.
     if args.DDP_impl == 'local':
         timers('backward-params-all-reduce').start()
-        model.allreduce_params(reduce_after=False,
-                               fp32_allreduce=args.fp32_allreduce)
+        for model_module in model:
+            model_module.allreduce_params(reduce_after=False,
+                                          fp32_allreduce=args.fp32_allreduce)
         timers('backward-params-all-reduce').stop()

+    # Barrier to measure backward stall.
+    timers('backward-pipeline-stall').start()
+    torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
+    timers('backward-pipeline-stall').stop()
+
     # All-reduce word_embeddings' grad across first and last stages to ensure
     # that word_embeddings parameters stay in sync.
     # This should only run for models that support pipelined model parallelism
     # (BERT and GPT-2).
     timers('backward-embedding-all-reduce').start()
-    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
+    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
+            mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
             mpu.get_pipeline_model_parallel_world_size() > 1:
-        unwrapped_model = model
-        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-            unwrapped_model = unwrapped_model.module
+        if mpu.is_pipeline_first_stage(ignore_virtual=True):
+            unwrapped_model = model[0]
+        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
+            unwrapped_model = model[-1]
+        unwrapped_model = unwrap_model(unwrapped_model,
+                                       (torchDDP, LocalDDP, FP16Module))
         if unwrapped_model.share_word_embeddings:
             word_embeddings_weight = unwrapped_model.word_embeddings_weight()
...
@@ -623,11 +385,11 @@ def train_step(forward_step_func, data_iterator,
     # Update parameters.
     timers('optimizer').start()
-    update_successfull, grad_norm = optimizer.step()
+    update_successful, grad_norm = optimizer.step()
     timers('optimizer').stop()

     # Update learning rate.
-    if update_successfull:
+    if update_successful:
         increment = get_num_microbatches() * \
                     args.micro_batch_size * \
                     args.data_parallel_size
...
@@ -636,7 +398,7 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1

-    if mpu.is_pipeline_last_stage():
+    if mpu.is_pipeline_last_stage(ignore_virtual=True):
         # Average loss across microbatches.
         loss_reduced = {}
         for key in losses_reduced[0]:
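For illustration only (not part of this commit): the `increment` passed to the learning-rate scheduler above is the number of samples consumed by one successful step. With illustrative sizes (16 microbatches, micro-batch size 4, data-parallel size 8) it equals the global batch size:

num_microbatches, micro_batch_size, data_parallel_size = 16, 4, 8
increment = num_microbatches * micro_batch_size * data_parallel_size
print(increment)  # 512 samples per optimizer step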
...
@@ -690,13 +452,16 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         if name in timers.timers:
             timers_to_log.append(name)

     add_to_logging('forward-compute')
+    add_to_logging('forward-pipeline-stall')
     add_to_logging('forward-recv')
     add_to_logging('forward-send')
-    add_to_logging('forward-send-backward-recv')
+    add_to_logging('forward-backward-send-forward-backward-recv')
     add_to_logging('backward-compute')
+    add_to_logging('backward-pipeline-stall')
     add_to_logging('backward-recv')
     add_to_logging('backward-send')
     add_to_logging('backward-send-forward-recv')
+    add_to_logging('backward-send-backward-recv')
     add_to_logging('backward-params-all-reduce')
     add_to_logging('backward-embedding-all-reduce')
     add_to_logging('optimizer-copy-to-main-grad')
...
@@ -745,7 +510,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                         normalizer=total_iterations)

     if iteration % args.log_interval == 0:
-        elapsed_time = timers('interval time').elapsed()
+        elapsed_time = timers('interval-time').elapsed()
         elapsed_time_per_iteration = elapsed_time / total_iterations
         if writer and torch.distributed.get_rank() == 0:
             if args.log_timers_to_tensorboard:
...
@@ -794,11 +559,11 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
     # Extra barrier is added to make sure
     # all ranks report the max time.
     torch.distributed.barrier()
-    timers('save checkpoint').start()
+    timers('save-checkpoint').start()
     save_checkpoint(iteration, model, optimizer, lr_scheduler)
     torch.distributed.barrier()
-    timers('save checkpoint').stop()
-    timers.log(['save checkpoint'])
+    timers('save-checkpoint').stop()
+    timers.log(['save-checkpoint'])


 def train(forward_step_func, model, optimizer, lr_scheduler,
...
@@ -811,7 +576,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     write_args_to_tensorboard()

     # Turn on training mode which enables dropout.
-    model.train()
+    for model_module in model:
+        model_module.train()

     # Tracking loss.
     total_loss_dict = {}
...
@@ -819,7 +585,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration

-    timers('interval time').start()
+    timers('interval-time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:
...
@@ -900,7 +666,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
     args = get_args()

     # Turn on evaluation mode which disables dropout.
-    model.eval()
+    for model_module in model:
+        model_module.eval()

     total_loss_dict = {}
...
@@ -912,37 +679,30 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                 print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                             args.eval_iters))
-            for _ in range(get_num_microbatches()):
-                if not mpu.is_pipeline_first_stage():
-                    input_tensor, _ = communicate(
-                        tensor_send_next=None,
-                        tensor_send_prev=None,
-                        recv_forward=True,
-                        recv_backward=False)
-                else:
-                    input_tensor = None
-
-                # Forward evaluation.
-                output_tensor = forward_step_func(data_iterator, model, input_tensor)
-
-                if mpu.is_pipeline_last_stage():
-                    _, loss_dict = output_tensor
-                    # Reduce across processes.
-                    for key in loss_dict:
-                        total_loss_dict[key] = total_loss_dict.get(
-                            key, torch.cuda.FloatTensor([0.0])) + \
-                            loss_dict[key]
-                else:
-                    communicate(tensor_send_next=output_tensor,
-                                tensor_send_prev=None,
-                                recv_forward=False,
-                                recv_backward=False)
+            if mpu.get_pipeline_model_parallel_world_size() > 1:
+                if args.virtual_pipeline_model_parallel_size is not None:
+                    forward_backward_func = forward_backward_pipelining_with_interleaving
+                else:
+                    forward_backward_func = forward_backward_pipelining_without_interleaving
+            else:
+                forward_backward_func = forward_backward_no_pipelining
+            loss_dicts = forward_backward_func(
+                forward_step_func, data_iterator, model, optimizer=None,
+                timers=None, forward_only=True)
+
+            if mpu.is_pipeline_last_stage(ignore_virtual=True):
+                # Reduce across processes.
+                for loss_dict in loss_dicts:
+                    for key in loss_dict:
+                        total_loss_dict[key] = total_loss_dict.get(
+                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                            * args.micro_batch_size \
                                            * get_num_microbatches()

     # Move model back to the train mode.
-    model.train()
+    for model_module in model:
+        model_module.train()

     for key in total_loss_dict:
         total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
...
megatron/utils.py
View file @
b2f57fc4
...
@@ -18,6 +18,7 @@

 import sys
 import torch
+from torch.nn.parallel import DistributedDataParallel as torchDDP
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
...
@@ -26,11 +27,25 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
-from megatron.checkpointing import save_checkpoint
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


+def unwrap_model(model, module_instances=(torchDDP)):
+    return_list = True
+    if not isinstance(model, list):
+        model = [model]
+        return_list = False
+    unwrapped_model = []
+    for model_module in model:
+        while isinstance(model_module, module_instances):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    if not return_list:
+        return unwrapped_model[0]
+    return unwrapped_model
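For illustration only (not part of this commit): a standalone sketch of how the new `unwrap_model()` helper behaves, using a stand-in wrapper class rather than Megatron's wrappers. It peels nested wrappers off each element and preserves whether a list or a single module was passed in:

import torch

class Wrapper(torch.nn.Module):            # stand-in for torchDDP / LocalDDP / FP16Module
    def __init__(self, module):
        super().__init__()
        self.module = module

def unwrap(model, module_instances=(Wrapper,)):
    return_list = isinstance(model, list)
    items = model if return_list else [model]
    unwrapped = []
    for m in items:
        while isinstance(m, module_instances):
            m = m.module
        unwrapped.append(m)
    return unwrapped if return_list else unwrapped[0]

core = torch.nn.Linear(4, 4)
assert unwrap(Wrapper(Wrapper(core))) is core    # single module in, single module out
assert unwrap([Wrapper(core)]) == [core]         # list in, list out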
 def calc_params_l2_norm(model):
     """Calculate l2 norm of parameters """
     # Remove duplicate params.
...
@@ -106,6 +121,8 @@ def print_params_min_max_norm(optimizer, iteration):

 def check_adlr_autoresume_termination(iteration, model,
                                       optimizer, lr_scheduler):
     """Check for autoresume signal and exit if it is received."""
+    from megatron.checkpointing import save_checkpoint
     args = get_args()
     autoresume = get_adlr_autoresume()
     # Add barrier to ensure consistnecy.
...
pretrain_bert.py
View file @
b2f57fc4
...
@@ -38,7 +38,7 @@ def model_provider():
     args = get_args()
     num_tokentypes = 2 if args.bert_binary_head else 0
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+    def model_provider_pipelined():
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
             model = BertModelFirstStage(
...
@@ -51,6 +51,17 @@ def model_provider():
         else:
             model = BertModelIntermediateStage(
                 num_tokentypes=num_tokentypes)
+        return model
+
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if args.virtual_pipeline_model_parallel_size is not None:
+            model = []
+            for i in range(args.virtual_pipeline_model_parallel_size):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                model.append(model_provider_pipelined())
+        else:
+            model = model_provider_pipelined()
     else:
         model = BertModel(
             num_tokentypes=num_tokentypes,
...
@@ -92,8 +103,8 @@ def forward_step(data_iterator, model, input_tensor):

     # Get the batch.
     timers('batch-generator').start()
-    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
-        = get_batch(data_iterator)
+    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
+        data_iterator)
     timers('batch-generator').stop()

     if not args.bert_binary_head:
...
pretrain_gpt.py
View file @
b2f57fc4
...
@@ -35,8 +35,8 @@ def model_provider():
     """Build the model."""
     print_rank_0('building GPT model ...')
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+    def model_provider_pipelined():
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
             model = GPTModelFirstStage(num_tokentypes=0)
...
@@ -46,6 +46,17 @@ def model_provider():
         else:
             model = GPTModelIntermediateStage(
                 num_tokentypes=0)
+        return model
+
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if args.virtual_pipeline_model_parallel_size is not None:
+            model = []
+            for i in range(args.virtual_pipeline_model_parallel_size):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                model.append(model_provider_pipelined())
+        else:
+            model = model_provider_pipelined()
     else:
         model = GPTModel(num_tokentypes=0, parallel_output=True)
...
tools/generate_samples_gpt.py
View file @
b2f57fc4
...
@@ -92,7 +92,9 @@ def main():
     """Main program."""

     initialize_megatron(extra_args_provider=add_text_generate_args,
-                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
+                                       'no_load_rng': True,
+                                       'no_load_optim': True})

     # Set up model and load checkpoint.
     model = get_model(model_provider)
...
tools/merge_mp_partitions.py
View file @
b2f57fc4
...
@@ -200,6 +200,8 @@ def main():
                                         'micro_batch_size': 1,
                                         'no_load_optim': True,
                                         'no_load_rng': True,
+                                        'no_save_optim': True,
+                                        'no_save_rng': True,
                                         'save_interval': 1})
     args = get_args()
...
@@ -240,6 +242,11 @@ def main():
     tokenizer = rebuild_tokenizer(args)
     mpu.initialize.set_tensor_model_parallel_world_size(
         args.tensor_model_parallel_size)
     for rank in range(args.tensor_model_parallel_size):
+        # Reset these since load_checkpoint asserts they are 0, but we are loading
+        # multiple checkpoints in the same process and they get set each time
+        args.consumed_train_samples = 0
+        args.consumed_valid_samples = 0
         mpu.initialize.set_tensor_model_parallel_rank(rank)
         checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
         model_ = get_model(model_type)
...