OpenDAS / Megatron-LM / Commits

Commit b2f57fc4, authored Feb 17, 2021 by Mostofa Patwary

    pulled latest megatron

Parents: a4b628ab, 76e3fca0

Showing 17 changed files with 1019 additions and 398 deletions (+1019 / -398).
  megatron/arguments.py              +37   -4
  megatron/checkpointing.py          +23   -12
  megatron/initialize.py             +2    -1
  megatron/model/fused_softmax.py    +10   -5
  megatron/model/module.py           +2    -2
  megatron/model/transformer.py      +21   -1
  megatron/mpu/__init__.py           +3    -0
  megatron/mpu/initialize.py         +46   -3
  megatron/optimizer/__init__.py     +14   -13
  megatron/p2p_communication.py      +272  -0
  megatron/schedules.py              +424  -0
  megatron/training.py               +110  -350
  megatron/utils.py                  +18   -1
  pretrain_bert.py                   +14   -3
  pretrain_gpt.py                    +13   -2
  tools/generate_samples_gpt.py      +3    -1
  tools/merge_mp_partitions.py       +7    -0
megatron/arguments.py  (view file @ b2f57fc4)

@@ -70,7 +70,7 @@ def parse_args(extra_args_provider=None, defaults={},
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
     assert args.world_size % model_parallel_size == 0, 'world size is not' \
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
         'size ({})'.format(args.world_size,
                            args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
     args.data_parallel_size = args.world_size // model_parallel_size

@@ -116,6 +116,18 @@ def parse_args(extra_args_provider=None, defaults={},
             print('setting global batch size to {}'.format(
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
+            'number of layers is not divisible by number of layers per virtual ' \
+            'pipeline stage'
+        args.virtual_pipeline_model_parallel_size = \
+            (args.num_layers // args.pipeline_model_parallel_size) // \
+            args.num_layers_per_virtual_pipeline_stage
+        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
+            'global batch size is not divisible by pipeline parallel size when ' \
+            'using interleaved schedule'
+    else:
+        args.virtual_pipeline_model_parallel_size = None

     # Parameters dtype.
     args.params_dtype = torch.float
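For orientation, here is a minimal sketch of the arithmetic the new
--num-layers-per-virtual-pipeline-stage argument drives; the concrete numbers
are hypothetical and not taken from the commit.

    # Hypothetical configuration: 16 transformer layers, 4 pipeline stages,
    # 2 layers per virtual pipeline stage.
    num_layers = 16
    pipeline_model_parallel_size = 4
    num_layers_per_virtual_pipeline_stage = 2

    # Mirrors the computation in parse_args() above.
    layers_per_stage = num_layers // pipeline_model_parallel_size       # 4
    virtual_pipeline_model_parallel_size = \
        layers_per_stage // num_layers_per_virtual_pipeline_stage       # 2 model chunks per stage
    print(virtual_pipeline_model_parallel_size)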
@@ -202,7 +214,23 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.checkpoint_activations, \
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'

+    # custom kernel constraints check
+    seq_len = args.seq_length
+    attn_batch_size = \
+        (args.num_attention_heads / args.tensor_model_parallel_size) * \
+        args.micro_batch_size
+    # constraints on sequence length and attn_batch_size to enable warp based
+    # optimization and upper triangular optimization (for causal mask)
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
+        print('WARNING: constraints for invoking optimized'
+              ' fused softmax kernel are not met. We default back to unfused'
+              ' kernel invocations.')

     # Load scaled_masked_softmax_fusion_kernels
     if args.masked_softmax_fusion:
         fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
@@ -478,9 +506,9 @@ def _add_checkpointing_args(parser):
                        help='Output directory to save checkpoints to.')
     group.add_argument('--save-interval', type=int, default=None,
                        help='Number of iterations between checkpoint saves.')
-    group.add_argument('--no-save-optim', action='store_true',
+    group.add_argument('--no-save-optim', action='store_true', default=None,
                        help='Do not save current optimizer.')
-    group.add_argument('--no-save-rng', action='store_true',
+    group.add_argument('--no-save-rng', action='store_true', default=None,
                        help='Do not save current rng state.')
     group.add_argument('--load', type=str, default=None,
                        help='Directory containing a model checkpoint.')

@@ -541,6 +569,8 @@ def _add_distributed_args(parser):
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')
+    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of layers per virtual pipeline stage')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')

@@ -548,6 +578,9 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
+    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
+                       help='Use scatter/gather to optimize communication of tensors in pipeline',
+                       dest='scatter_gather_tensors_in_pipeline')
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,
megatron/checkpointing.py  (view file @ b2f57fc4)

@@ -21,12 +21,12 @@ import sys
 import numpy as np

 import torch
 from torch.nn.parallel import DistributedDataParallel as torchDDP

 from megatron import (get_args,
                       mpu,
                       print_rank_0,
-                      update_num_microbatches)
+                      update_num_microbatches,
+                      utils)

 _CHECKPOINT_VERSION = None

@@ -111,8 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

@@ -124,7 +123,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 3.0
         state_dict['iteration'] = iteration
-        state_dict['model'] = model.state_dict_for_save_checkpoint()
+        if len(model) == 1:
+            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
+        else:
+            for i in range(len(model)):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

         # Optimizer stuff.
         if not args.no_save_optim:
@@ -238,8 +242,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     args = get_args()
     load_dir = getattr(args, load_arg)

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)
+
     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)

@@ -324,7 +328,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
             print_rank_0('could not find arguments in the checkpoint ...')

     # Model.
-    model.load_state_dict(state_dict['model'], strict=strict)
+    if len(model) == 1:
+        model[0].load_state_dict(state_dict['model'], strict=strict)
+    else:
+        for i in range(len(model)):
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

     # Fix up query/key/value matrix ordering if needed
     checkpoint_version = get_checkpoint_version()

@@ -352,12 +361,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
             np.random.set_state(state_dict['np_rng_state'])
             torch.set_rng_state(state_dict['torch_rng_state'])
             torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+            # Check for empty states array
+            if not state_dict['rng_tracker_states']:
+                raise KeyError
             mpu.get_cuda_rng_tracker().set_states(
                 state_dict['rng_tracker_states'])
         except KeyError:
-            print_rank_0('Unable to load optimizer from checkpoint {}. '
+            print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
-                         'attempting to load the optimizer state, '
+                         'attempting to load the rng state, '
                          'exiting ...'.format(checkpoint_name))
             sys.exit()

@@ -376,8 +388,7 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f
     args = get_args()

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     load_path = args.load if from_realm_chkpt else args.ict_load
megatron/initialize.py  (view file @ b2f57fc4)

@@ -133,7 +133,8 @@ def _initialize_distributed():
             print('model parallel is already initialized')
         else:
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size)
+                                          args.pipeline_model_parallel_size,
+                                          args.virtual_pipeline_model_parallel_size)


 def _init_autoresume():
megatron/model/fused_softmax.py  (view file @ b2f57fc4)

@@ -113,18 +113,23 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert (
             self.scale is None or softmax_in_fp32
         ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
         # [b, np, sq, sk]
+        assert input.dim() == 4
         data_size = input.size()
         query_seq_len = data_size[-2]
         key_seq_len = data_size[-1]
-        assert input.dim() == 4
+        attn_batch_size = data_size[0] * data_size[1]

-        # invoke custom kernel
-        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
-            query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
+        # constraints on various tensor dimensions to enable warp based
+        # optimization and upper triangular optimization (for causal mask)
+        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
+            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
+
+        # invoke custom kernel
+        if self.input_in_fp16 and mask is not None and \
+            custom_kernel_constraint and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0

             if self.attn_mask_type == AttnMaskType.causal:
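As a quick sanity check of the constraint above, the eligibility test can be
evaluated in isolation; the tensor sizes below are made up for illustration.

    # Hypothetical [b, np, sq, sk] sizes.
    micro_batch_size, num_heads = 4, 16
    query_seq_len, key_seq_len = 1024, 1024
    attn_batch_size = micro_batch_size * num_heads
    custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
        query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
    print(custom_kernel_constraint)  # True -> the fused kernel path is eligible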
megatron/model/module.py  (view file @ b2f57fc4)

@@ -50,9 +50,9 @@ class MegatronModule(torch.nn.Module):
     def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage():
+        if mpu.is_pipeline_first_stage(ignore_virtual=True):
             return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage():
+        if mpu.is_pipeline_last_stage(ignore_virtual=True):
             if not self.share_word_embeddings:
                 raise Exception('word_embeddings_weight() called for last '
                                 'stage, but share_word_embeddings is false')
megatron/model/transformer.py  (view file @ b2f57fc4)

@@ -552,7 +552,27 @@ class ParallelTransformer(MegatronModule):
                 layer_number,
                 layer_type=layer_type,
                 self_attn_mask_type=self_attn_mask_type)
-        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+        if args.virtual_pipeline_model_parallel_size is not None:
+            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
+                'num_layers_per_stage must be divisible by ' \
+                'virtual_pipeline_model_parallel_size'
+            # Number of layers in each model chunk is the number of layers in the stage,
+            # divided by the number of model chunks in a stage.
+            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0]  [2]  [4]  [6]
+            # Stage 1: [1]  [3]  [5]  [7]
+            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0, 1]  [4, 5]
+            # Stage 1: [2, 3]  [6, 7]
+            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
+                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
+                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
+        else:
+            # Each stage gets a contiguous set of layers.
+            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])
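The offset arithmetic above can be checked with a small standalone script; the
sizes match the example in the code comments (8 layers, 2 stages, 2 virtual
chunks per stage), and everything else is illustrative.

    num_layers_total, pipeline_size, virtual_size = 8, 2, 2
    layers_per_chunk = num_layers_total // pipeline_size // virtual_size   # 2
    for pipe_rank in range(pipeline_size):
        for virtual_rank in range(virtual_size):
            offset = virtual_rank * (num_layers_total // virtual_size) + \
                pipe_rank * layers_per_chunk
            # build_layer() above receives offset + i + 1 (layer numbers are 1-based).
            chunk = [offset + i for i in range(layers_per_chunk)]
            print(f'stage {pipe_rank}, chunk {virtual_rank}: layers {chunk}')
    # stage 0 gets [0, 1] and [4, 5]; stage 1 gets [2, 3] and [6, 7]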
megatron/mpu/__init__.py  (view file @ b2f57fc4)

@@ -38,6 +38,7 @@ from .initialize import get_pipeline_model_parallel_next_rank
 from .initialize import get_pipeline_model_parallel_prev_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
+from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized

@@ -58,6 +59,8 @@ from .random import get_cuda_rng_tracker
 from .random import init_checkpointed_activations_memory_buffer
 from .random import model_parallel_cuda_manual_seed
 from .random import reset_checkpointed_activations_memory_buffer
+from .random import gather_split_1d_tensor
+from .random import split_tensor_into_1d_equal_chunks

 from .utils import divide
 from .utils import split_tensor_along_last_dim
megatron/mpu/initialize.py  (view file @ b2f57fc4)

@@ -32,6 +32,9 @@ _EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None

+_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+
 # These values enable us to change the mpu sizes on the fly.
 _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None

@@ -48,7 +51,8 @@ def is_unitialized():
 def initialize_model_parallel(tensor_model_parallel_size_=1,
-                              pipeline_model_parallel_size_=1):
+                              pipeline_model_parallel_size_=1,
+                              virtual_pipeline_model_parallel_size_=None):
     """
     Initialize model data parallel groups.

@@ -91,6 +95,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
     num_data_parallel_groups = world_size // data_parallel_size

+    if virtual_pipeline_model_parallel_size_ is not None:
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
+
     rank = torch.distributed.get_rank()

     # Build the data-parallel groups.

@@ -258,17 +268,46 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


-def is_pipeline_first_stage():
+def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        if get_virtual_pipeline_model_parallel_world_size() is not None and \
+            get_virtual_pipeline_model_parallel_rank() != 0:
+            return False
     return get_pipeline_model_parallel_rank() == 0


-def is_pipeline_last_stage():
+def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        virtual_pipeline_model_parallel_world_size = \
+            get_virtual_pipeline_model_parallel_world_size()
+        if virtual_pipeline_model_parallel_world_size is not None and \
+            get_virtual_pipeline_model_parallel_rank() != (
+                virtual_pipeline_model_parallel_world_size - 1):
+            return False
     return get_pipeline_model_parallel_rank() == (
         get_pipeline_model_parallel_world_size() - 1)


+def get_virtual_pipeline_model_parallel_rank():
+    """Return the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+
+
+def set_virtual_pipeline_model_parallel_rank(rank):
+    """Set the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
+
+
+def get_virtual_pipeline_model_parallel_world_size():
+    """Return the virtual pipeline-parallel world size."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+
+
 def get_tensor_model_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
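The ignore_virtual flag changes which model chunk counts as the boundary stage.
A self-contained sketch of the intended behaviour (it does not import megatron;
all names are local to the example):

    def is_first(pipeline_rank, virtual_rank, ignore_virtual=False):
        # Mirrors is_pipeline_first_stage() above when virtual pipelining is on.
        if not ignore_virtual and virtual_rank != 0:
            return False
        return pipeline_rank == 0

    assert is_first(0, 0) is True
    assert is_first(0, 1) is False                      # later chunk on the same rank
    assert is_first(0, 1, ignore_virtual=True) is True  # physical first stage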
@@ -276,11 +315,13 @@ def get_tensor_model_parallel_src_rank():
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size

+
 def get_pipeline_model_parallel_first_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
     return _PIPELINE_GLOBAL_RANKS[0]

+
 def get_pipeline_model_parallel_last_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"

@@ -294,6 +335,7 @@ def get_pipeline_model_parallel_next_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]

+
 def get_pipeline_model_parallel_prev_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"

@@ -301,6 +343,7 @@ def get_pipeline_model_parallel_prev_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]

+
 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
     return torch.distributed.get_world_size(group=get_data_parallel_group())
megatron/optimizer/__init__.py  (view file @ b2f57fc4)

@@ -23,7 +23,7 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer


-def _get_params_for_weight_decay_optimization(module):
+def _get_params_for_weight_decay_optimization(modules):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and biases will have no weight decay but the rest will.
     """

@@ -32,18 +32,19 @@ def _get_params_for_weight_decay_optimization(module):
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module_ in module.modules():
-        if isinstance(module_, LayerNorm):
-            no_weight_decay_params['params'].extend(
-                [p for p in list(module_._parameters.values())
-                 if p is not None])
-        else:
-            weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n != 'bias'])
-            no_weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n == 'bias'])
+    for module in modules:
+        for module_ in module.modules():
+            if isinstance(module_, LayerNorm):
+                no_weight_decay_params['params'].extend(
+                    [p for p in list(module_._parameters.values())
+                     if p is not None])
+            else:
+                weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n != 'bias'])
+                no_weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n == 'bias'])

     return weight_decay_params, no_weight_decay_params
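A quick way to see what the grouping above produces, using plain torch modules
in place of Megatron's fused LayerNorm (illustrative only):

    import torch

    modules = [torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))]
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, torch.nn.LayerNorm):
                no_weight_decay_params['params'].extend(
                    [p for p in module_._parameters.values() if p is not None])
            else:
                weight_decay_params['params'].extend(
                    [p for n, p in module_._parameters.items()
                     if p is not None and n != 'bias'])
                no_weight_decay_params['params'].extend(
                    [p for n, p in module_._parameters.items()
                     if p is not None and n == 'bias'])
    print(len(weight_decay_params['params']),      # 1: the Linear weight
          len(no_weight_decay_params['params']))   # 3: Linear bias + LayerNorm weight/bias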
megatron/p2p_communication.py  (new file, 0 → 100644, view file @ b2f57fc4)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
import operator
import torch

from megatron import get_args
from megatron import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 use_ring_exchange=False):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.

    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            mpu.get_tensor_model_parallel_world_size()
    else:
        tensor_chunk_shape = tensor_shape
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if use_ring_exchange:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()

    return tensor_recv_prev, tensor_recv_next


def recv_forward(timers=None, use_ring_exchange=False):
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None, tensor_send_prev=None,
            recv_prev=True, recv_next=False,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor


def recv_backward(timers=None, use_ring_exchange=False):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=None, tensor_send_prev=None,
            recv_prev=False, recv_next=True,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, timers=None, use_ring_exchange=False):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
        _communicate(
            tensor_send_next=output_tensor, tensor_send_prev=None,
            recv_prev=False, recv_next=False,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('forward-send').stop()


def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send').start()
        _communicate(
            tensor_send_next=None, tensor_send_prev=input_tensor_grad,
            recv_prev=False, recv_next=False,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('forward-send-backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor, tensor_send_prev=None,
            recv_prev=False, recv_next=True,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None, tensor_send_prev=input_tensor_grad,
            recv_prev=True, recv_next=False,
            use_ring_exchange=use_ring_exchange)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor, tensor_send_prev=None,
        recv_prev=recv_prev, recv_next=False,
        use_ring_exchange=True)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None, tensor_send_prev=input_tensor_grad,
        recv_prev=False, recv_next=recv_next,
        use_ring_exchange=True)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev, recv_next, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor, tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev, recv_next=recv_next,
        use_ring_exchange=True)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
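The non-ring-exchange path above boils down to torch.distributed's batched
point-to-point API. A minimal standalone illustration of that pattern (two
processes exchanging a tensor with their neighbours; the script name and the
gloo backend are assumptions, not part of the commit):

    # Run with: torchrun --nproc_per_node=2 p2p_demo.py
    import torch
    import torch.distributed as dist

    def main():
        dist.init_process_group(backend='gloo')
        rank, world = dist.get_rank(), dist.get_world_size()
        send_buf = torch.full((4,), float(rank))
        recv_buf = torch.empty(4)
        # One isend to the next rank and one irecv from the previous rank,
        # issued together, then waited on - the same shape as _communicate().
        ops = [dist.P2POp(dist.isend, send_buf, (rank + 1) % world),
               dist.P2POp(dist.irecv, recv_buf, (rank - 1) % world)]
        for req in dist.batch_isend_irecv(ops):
            req.wait()
        print(f'rank {rank} received {recv_buf.tolist()}')

    if __name__ == '__main__':
        main()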
megatron/schedules.py  (new file, 0 → 100644, view file @ b2f57fc4)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication


def forward_step(forward_step_func, data_iterator, model, input_tensor,
                 losses_reduced):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor."""
    timers = get_timers()

    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

    return output_tensor


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""
    args = get_args()
    timers = get_timers()

    timers('backward-compute').start()

    # Retain the grad on the input_tensor.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad

    timers('backward-compute').stop()

    return input_tensor_grad


def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    Returns dictionary with losses."""
    assert len(model) == 1
    model = model[0]

    losses_reduced = []
    for i in range(get_num_microbatches()):
        input_tensor, output_tensor_grad = None, None
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if not forward_only:
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

    return losses_reduced


def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
                                                  model, optimizer, timers,
                                                  forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id

    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(
        p2p_communication.recv_forward(timers, use_ring_exchange=True))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev, timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
            next_forward_model_chunk_id += 1
        else:
            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
                                                             forward=True)

        recv_next = True
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
            next_backward_model_chunk_id = get_model_chunk_id(
                backward_k - (pipeline_parallel_size - 1), forward=False)
            if next_backward_model_chunk_id == 0:
                recv_next = False
            next_backward_model_chunk_id -= 1
        else:
            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
                                                              forward=False)

        # If last iteration, don't receive; we already received one extra
        # before the start of the for loop.
        if k == (num_microbatches_remaining - 1):
            recv_prev = False

        # Communicate tensors.
        input_tensor, output_tensor_grad = \
            p2p_communication.send_forward_backward_recv_forward_backward(
                output_tensor, input_tensor_grad,
                recv_prev=recv_prev, recv_next=recv_next,
                timers=timers)

        # Put input_tensor and output_tensor_grad in data structures in the
        # right location.
        if recv_prev:
            input_tensors[next_forward_model_chunk_id].append(input_tensor)
        if recv_next:
            output_tensor_grads[next_backward_model_chunk_id].append(
                output_tensor_grad)

    # Run cooldown backward passes (flush out pipeline).
    if not forward_only:
        if all_warmup_microbatches:
            output_tensor_grads[num_model_chunks-1].append(
                p2p_communication.recv_backward(timers, use_ring_exchange=True))
        for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                if next_backward_model_chunk_id == (num_model_chunks - 1):
                    recv_next = False
            if k == (num_microbatches - 1):
                recv_next = False
            output_tensor_grads[next_backward_model_chunk_id].append(
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad, recv_next, timers))

    return losses_reduced


def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches,
                                  num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        # Barrier before first receive to measure forward stall.
        if i == (num_warmup_microbatches - 1):
            timers('forward-pipeline-stall').start()
            torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
            timers('forward-pipeline-stall').stop()
        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Barrier before first receive to measure forward stall.
    if num_warmup_microbatches == 0:
        timers('forward-pipeline-stall').start()
        torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
        timers('forward-pipeline-stall').stop()

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor, timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced
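To see how the warmup bookkeeping in
forward_backward_pipelining_with_interleaving() plays out, the counts can be
tabulated for a hypothetical configuration (4 pipeline stages, 2 model chunks,
8 microbatches per chunk); none of these numbers come from the commit itself.

    pipeline_parallel_size = 4
    num_model_chunks = 2
    num_microbatches = 8 * num_model_chunks   # get_num_microbatches() * num_model_chunks
    for pipeline_parallel_rank in range(pipeline_parallel_size):
        warmup = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
        warmup += (num_model_chunks - 1) * pipeline_parallel_size
        warmup = min(warmup, num_microbatches)
        print(pipeline_parallel_rank, warmup, num_microbatches - warmup)
    # Earlier stages run more warmup forward passes; later stages enter the
    # 1F1B steady state sooner.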
megatron/training.py  (view file @ b2f57fc4)

@@ -46,8 +46,12 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
+from megatron.schedules import forward_backward_no_pipelining
+from megatron.schedules import forward_backward_pipelining_without_interleaving
+from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory

@@ -107,23 +111,32 @@ def pretrain(train_valid_test_dataset_provider,
     timers = get_timers()

     # Model, optimizer, and learning rate.
-    timers('model and optimizer').start()
+    timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
-    timers('model and optimizer').stop()
+    timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')

     # Data stuff.
-    timers('train/valid/test data iterators').start()
-    train_data_iterator, valid_data_iterator, test_data_iterator \
-        = build_train_valid_test_data_iterators(
-            train_valid_test_dataset_provider)
-    timers('train/valid/test data iterators').stop()
+    timers('train/valid/test-data-iterators-setup').start()
+    if args.virtual_pipeline_model_parallel_size is not None:
+        all_data_iterators = [
+            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
+            for _ in range(len(model))
+        ]
+        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
+        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
+        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
+    else:
+        train_data_iterator, valid_data_iterator, test_data_iterator \
+            = build_train_valid_test_data_iterators(
+                train_valid_test_dataset_provider)
+    timers('train/valid/test-data-iterators-setup').stop()
     print_datetime('after dataloaders are built')

     # Print setup timing.
-    print_rank_0('done with setups ...')
-    timers.log(['model and optimizer', 'train/valid/test data iterators'])
+    print_rank_0('done with setup ...')
+    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
     print_rank_0('training ...')

     iteration = 0
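When the interleaved schedule is active, the branch above builds one
(train, valid, test) iterator triple per model chunk and then transposes the
list. A toy illustration of the resulting shapes, with placeholder strings
standing in for real iterators:

    all_data_iterators = [('train0', 'valid0', 'test0'),
                          ('train1', 'valid1', 'test1')]   # one triple per chunk
    train_data_iterator = [its[0] for its in all_data_iterators]  # ['train0', 'train1']
    valid_data_iterator = [its[1] for its in all_data_iterators]
    test_data_iterator = [its[2] for its in all_data_iterators]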
@@ -185,13 +198,16 @@ def get_model(model_provider_func):
     # Build model on cpu.
     model = model_provider_func()
+    if not isinstance(model, list):
+        model = [model]

     # Set tensor model parallel attributes if not set.
     # Only parameters that are already tensor model parallel have these
     # attributes set for them. We should make sure the default attributes
     # are set for all params so the optimizer can use them.
-    for param in model.parameters():
-        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+    for model_module in model:
+        for param in model_module.parameters():
+            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:

@@ -199,22 +215,25 @@ def get_model(model_provider_func):
               'model parallel rank ({}, {}): {}'.format(
             mpu.get_tensor_model_parallel_rank(),
             mpu.get_pipeline_model_parallel_rank(),
-            sum([p.nelement() for p in model.parameters()])), flush=True)
+            sum([sum([p.nelement() for p in model_module.parameters()])
+                 for model_module in model])), flush=True)

     # GPU allocation.
-    model.cuda(torch.cuda.current_device())
+    for model_module in model:
+        model_module.cuda(torch.cuda.current_device())

     # Fp16 conversion.
     if args.fp16:
-        model = FP16Module(model)
+        model = [FP16Module(model_module) for model_module in model]

     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
-        model = torchDDP(model, device_ids=[i], output_device=i,
-                         process_group=mpu.get_data_parallel_group())
+        model = [torchDDP(model_module, device_ids=[i], output_device=i,
+                          process_group=mpu.get_data_parallel_group())
+                 for model_module in model]
         return model

     if args.DDP_impl == 'local':
-        model = LocalDDP(model)
+        model = [LocalDDP(model_module) for model_module in model]
         return model

     raise NotImplementedError('Unknown DDP implementation specified: {}. '
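The nested sum above simply accumulates parameter counts over the list of
model chunks; a self-contained check with stand-in modules:

    import torch.nn as nn

    model = [nn.Linear(4, 4), nn.Linear(4, 2)]   # stand-ins for model chunks
    total = sum([sum([p.nelement() for p in model_module.parameters()])
                 for model_module in model])
    print(total)   # (4*4 + 4) + (4*2 + 2) = 30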
@@ -270,9 +289,8 @@ def setup_model_and_optimizer(model_provider_func):
     model = get_model(model_provider_func)

-    unwrapped_model = model
-    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-        unwrapped_model = unwrapped_model.module
+    unwrapped_model = unwrap_model(model,
+                                   (torchDDP, LocalDDP, FP16Module))

     optimizer = get_megatron_optimizer(unwrapped_model)
     lr_scheduler = get_learning_rate_scheduler(optimizer)
@@ -282,305 +300,35 @@ def setup_model_and_optimizer(model_provider_func):
# Extra barrier is added to make sure all ranks report the
# max time.
torch
.
distributed
.
barrier
()
timers
(
'load
checkpoint'
).
start
()
timers
(
'load
-
checkpoint'
).
start
()
args
.
iteration
=
load_checkpoint
(
model
,
optimizer
,
lr_scheduler
)
torch
.
distributed
.
barrier
()
timers
(
'load
checkpoint'
).
stop
()
timers
.
log
([
'load
checkpoint'
])
timers
(
'load
-
checkpoint'
).
stop
()
timers
.
log
([
'load
-
checkpoint'
])
else
:
args
.
iteration
=
0
# We only support local DDP with multiple micro-batches.
if
get_num_microbatches
()
>
1
:
assert
args
.
DDP_impl
==
'local'
if
len
(
model
)
==
1
:
assert
args
.
DDP_impl
==
'local'
if
mpu
.
get_pipeline_model_parallel_world_size
()
>
1
:
assert
args
.
DDP_impl
==
'local'
# get model without FP16 and/or TorchDDP wrappers
unwrapped_model
=
model
while
hasattr
(
unwrapped_model
,
'module'
):
unwrapped_model
=
unwrapped_model
.
module
if
args
.
iteration
==
0
and
hasattr
(
unwrapped_model
,
'init_state_dict_from_bert'
):
print_rank_0
(
"Initializing ICT from pretrained BERT model"
)
unwrapped_model
.
init_state_dict_from_bert
()
if
args
.
fp16
:
optimizer
.
reload_model_params
()
model
=
unwrap_model
(
model
)
for
module
in
model
:
if
args
.
iteration
==
0
and
hasattr
(
module
,
'init_state_dict_from_bert'
):
print
(
"Initializing ICT from pretrained BERT model"
,
flush
=
True
)
module
.
init_state_dict_from_bert
()
if
args
.
fp16
:
optimizer
.
reload_model_params
()
return
model
,
optimizer
,
lr_scheduler
def
communicate
(
tensor_send_next
,
tensor_send_prev
,
recv_forward
,
recv_backward
):
"""Communicate tensors between stages."""
args
=
get_args
()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev
=
None
tensor_recv_next
=
None
tensor_shape
=
(
args
.
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
)
dtype
=
args
.
params_dtype
if
args
.
fp32_residual_connection
:
dtype
=
torch
.
float
if
recv_forward
:
tensor_recv_prev
=
torch
.
empty
(
tensor_shape
,
requires_grad
=
True
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
dtype
)
if
recv_backward
:
tensor_recv_next
=
torch
.
empty
(
tensor_shape
,
requires_grad
=
True
,
device
=
torch
.
cuda
.
current_device
(),
dtype
=
dtype
)
# Send tensors in both the forward and backward directions as appropriate.
ops
=
[]
if
tensor_send_prev
is
not
None
:
send_prev_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
isend
,
tensor_send_prev
,
mpu
.
get_pipeline_model_parallel_prev_rank
())
ops
.
append
(
send_prev_op
)
if
tensor_recv_prev
is
not
None
:
recv_prev_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
irecv
,
tensor_recv_prev
,
mpu
.
get_pipeline_model_parallel_prev_rank
())
ops
.
append
(
recv_prev_op
)
if
tensor_send_next
is
not
None
:
send_next_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
isend
,
tensor_send_next
,
mpu
.
get_pipeline_model_parallel_next_rank
())
ops
.
append
(
send_next_op
)
if
tensor_recv_next
is
not
None
:
recv_next_op
=
torch
.
distributed
.
P2POp
(
torch
.
distributed
.
irecv
,
tensor_recv_next
,
mpu
.
get_pipeline_model_parallel_next_rank
())
ops
.
append
(
recv_next_op
)
reqs
=
torch
.
distributed
.
batch_isend_irecv
(
ops
)
for
req
in
reqs
:
req
.
wait
()
return
tensor_recv_prev
,
tensor_recv_next
def
backward_step
(
optimizer
,
model
,
input_tensor
,
output_tensor
,
output_tensor_grad
):
"""Backward step."""
args
=
get_args
()
timers
=
get_timers
()
# Retain the grad on the input_tensor.
if
input_tensor
is
not
None
:
input_tensor
.
retain_grad
()
# Backward pass.
if
output_tensor_grad
is
None
:
output_tensor
=
optimizer
.
scale_loss
(
output_tensor
)
torch
.
autograd
.
backward
(
output_tensor
,
grad_tensors
=
output_tensor_grad
)
# Collect the grad of the input_tensor.
input_tensor_grad
=
None
if
input_tensor
is
not
None
:
input_tensor_grad
=
input_tensor
.
grad
return
input_tensor_grad
def
forward_step_with_communication
(
forward_step_func
,
data_iterator
,
model
,
input_tensors
,
output_tensors
,
losses_reduced
,
timers
):
args
=
get_args
()
if
not
mpu
.
is_pipeline_first_stage
():
timers
(
'forward-recv'
).
start
()
input_tensor
,
_
=
communicate
(
tensor_send_next
=
None
,
tensor_send_prev
=
None
,
recv_forward
=
True
,
recv_backward
=
False
)
timers
(
'forward-recv'
).
stop
()
else
:
input_tensor
=
None
# Forward model for one step.
timers
(
'forward-compute'
).
start
()
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
timers
(
'forward-compute'
).
stop
()
if
mpu
.
is_pipeline_last_stage
():
loss
,
loss_reduced
=
output_tensor
output_tensor
=
loss
/
get_num_microbatches
()
losses_reduced
.
append
(
loss_reduced
)
else
:
timers
(
'forward-send'
).
start
()
communicate
(
tensor_send_next
=
output_tensor
,
tensor_send_prev
=
None
,
recv_forward
=
False
,
recv_backward
=
False
)
timers
(
'forward-send'
).
stop
()
input_tensors
.
append
(
input_tensor
)
output_tensors
.
append
(
output_tensor
)
def
backward_step_with_communication
(
optimizer
,
model
,
input_tensors
,
output_tensors
,
timers
):
input_tensor
=
input_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
if
mpu
.
is_pipeline_last_stage
():
output_tensor_grad
=
None
else
:
timers
(
'backward-recv'
).
start
()
_
,
output_tensor_grad
=
communicate
(
tensor_send_next
=
None
,
tensor_send_prev
=
None
,
recv_forward
=
False
,
recv_backward
=
True
)
timers
(
'backward-recv'
).
stop
()
# Backward pass for one step.
timers
(
'backward-compute'
).
start
()
input_grad_tensor
=
\
backward_step
(
optimizer
,
model
,
input_tensor
,
output_tensor
,
output_tensor_grad
)
timers
(
'backward-compute'
).
stop
()
if
not
mpu
.
is_pipeline_first_stage
():
timers
(
'backward-send'
).
start
()
communicate
(
tensor_send_next
=
None
,
tensor_send_prev
=
input_grad_tensor
,
recv_forward
=
False
,
recv_backward
=
False
)
timers
(
'backward-send'
).
stop
()
def
forward_and_backward_steps_with_communication
(
forward_step_func
,
data_iterator
,
model
,
optimizer
,
input_tensor
,
last_microbatch
,
input_tensors
,
output_tensors
,
losses_reduced
,
timers
):
args
=
get_args
()
# Forward model for one step.
timers
(
'forward-compute'
).
start
()
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
timers
(
'forward-compute'
).
stop
()
if
mpu
.
is_pipeline_last_stage
():
loss
,
loss_reduced
=
output_tensor
output_tensor
=
loss
/
get_num_microbatches
()
output_tensor_grad
=
None
losses_reduced
.
append
(
loss_reduced
)
else
:
timers
(
'forward-send-backward-recv'
).
start
()
_
,
output_tensor_grad
=
communicate
(
tensor_send_next
=
output_tensor
,
tensor_send_prev
=
None
,
recv_forward
=
False
,
recv_backward
=
True
)
timers
(
'forward-send-backward-recv'
).
stop
()
input_tensors
.
append
(
input_tensor
)
output_tensors
.
append
(
output_tensor
)
input_tensor
=
input_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
# Backward pass for one step.
timers
(
'backward-compute'
).
start
()
input_grad_tensor
=
\
backward_step
(
optimizer
,
model
,
input_tensor
,
output_tensor
,
output_tensor_grad
)
timers
(
'backward-compute'
).
stop
()
if
not
mpu
.
is_pipeline_first_stage
():
timers
(
'backward-send-forward-recv'
).
start
()
input_tensor
,
_
=
communicate
(
tensor_send_next
=
None
,
tensor_send_prev
=
input_grad_tensor
,
recv_forward
=
(
not
last_microbatch
),
recv_backward
=
False
)
timers
(
'backward-send-forward-recv'
).
stop
()
else
:
input_tensor
=
None
return
input_tensor
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers):
    """Run forward and backward passes without inter-stage communication."""
    args = get_args()

    losses_reduced = []
    for i in range(get_num_microbatches()):
        timers('forward-compute').start()
        loss, loss_reduced = forward_step_func(data_iterator, model,
                                               input_tensor=None)
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
        timers('forward-compute').stop()

        timers('backward-compute').start()
        output_tensor_grad = None
        backward_step(optimizer, model, input_tensor=None,
                      output_tensor=output_tensor, output_tensor_grad=None)
        timers('backward-compute').stop()

    return losses_reduced
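A small self-contained illustration (toy tensors, not Megatron objects) of why each microbatch loss above is divided by get_num_microbatches(): accumulating the scaled gradients over the loop is equivalent to taking the gradient of the mean loss over the whole global batch.

# Toy example with assumed data; gradients accumulate in w.grad across the loop.
import torch

num_microbatches = 4
w = torch.zeros(3, requires_grad=True)
microbatches = [torch.randn(3) for _ in range(num_microbatches)]

for x in microbatches:
    loss = (w - x).pow(2).sum() / num_microbatches  # scale before backward
    loss.backward()

# w.grad now equals the gradient of the average loss over the four microbatches.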
def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                optimizer, timers):
    """Run 1F1B schedule, with communication and warmup + cooldown
    microbatches as needed."""
    args = get_args()

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors, losses_reduced, timers)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        if mpu.is_pipeline_first_stage():
            input_tensor = None
        else:
            timers('forward-recv').start()
            input_tensor, _ = communicate(
                tensor_send_next=None,
                tensor_send_prev=None,
                recv_forward=True,
                recv_backward=False)
            timers('forward-recv').stop()

    # Run 1F1B.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))
        input_tensor = \
            forward_and_backward_steps_with_communication(
                forward_step_func, data_iterator, model, optimizer,
                input_tensor, last_iteration,
                input_tensors, output_tensors, losses_reduced, timers)

    # Run cooldown backward passes.
    for i in range(num_warmup_microbatches):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)

    return losses_reduced
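To make the warmup arithmetic above concrete, here is a small stand-alone calculation with assumed sizes (4 pipeline stages, 8 microbatches); it is not Megatron code, just the same formula evaluated per stage.

# Assumed sizes for illustration only.
pipeline_world_size = 4
num_microbatches = 8
for rank in range(pipeline_world_size):
    warmup = min(pipeline_world_size - rank - 1, num_microbatches)
    steady = num_microbatches - warmup
    print(f'stage {rank}: warmup={warmup}, 1F1B={steady}, cooldown={warmup}')
# Stage 0 runs 3 warmup forwards and 3 cooldown backwards, while the last
# stage runs all 8 microbatches in the steady 1F1B phase.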
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
...
...
@@ -591,29 +339,43 @@ def train_step(forward_step_func, data_iterator,
    optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        losses_reduced = forward_backward_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        losses_reduced = forward_backward_no_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)
        forward_backward_func = forward_backward_no_pipelining
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model, optimizer, timers,
        forward_only=False)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        for model_module in model:
            model_module.allreduce_params(reduce_after=False,
                                          fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # Barrier to measure backward stall.
    timers('backward-pipeline-stall').start()
    torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
    timers('backward-pipeline-stall').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        unwrapped_model = model
        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
            unwrapped_model = unwrapped_model.module
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, FP16Module))
        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
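The actual synchronization call sits in the elided lines of this hunk; as a hedged sketch of what tying the first- and last-stage word embeddings typically involves, the gradient of the shared weight would be all-reduced over a process group that spans just those two ranks. The group getter name below is an assumption for illustration, and this fragment relies on the surrounding train_step context rather than being standalone.

# Hedged illustration only; mpu.get_embedding_group() is assumed to return the
# process group containing the first and last pipeline stages.
if unwrapped_model.share_word_embeddings:
    word_embeddings_weight = unwrapped_model.word_embeddings_weight()
    torch.distributed.all_reduce(word_embeddings_weight.grad,
                                 group=mpu.get_embedding_group())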
...
...
@@ -623,11 +385,11 @@ def train_step(forward_step_func, data_iterator,
    # Update parameters.
    timers('optimizer').start()
    update_successfull, grad_norm = optimizer.step()
    update_successful, grad_norm = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    if update_successfull:
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
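As a quick sanity check on the increment above, with assumed sizes the product is simply the number of samples in one global batch, i.e. what each iteration consumes.

# Assumed sizes for illustration: 8 microbatches, micro batch 4, 16 data-parallel replicas.
num_microbatches = 8
micro_batch_size = 4
data_parallel_size = 16
increment = num_microbatches * micro_batch_size * data_parallel_size
print(increment)  # 512 samples consumed per iteration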
...
...
@@ -636,7 +398,7 @@ def train_step(forward_step_func, data_iterator,
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage():
    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
...
...
@@ -690,13 +452,16 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
        if name in timers.timers:
            timers_to_log.append(name)

    add_to_logging('forward-compute')
    add_to_logging('forward-pipeline-stall')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-send-backward-recv')
    add_to_logging('forward-backward-send-forward-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-pipeline-stall')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-send-backward-recv')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('optimizer-copy-to-main-grad')
...
...
@@ -745,7 +510,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                     normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer and torch.distributed.get_rank() == 0:
            if args.log_timers_to_tensorboard:
...
...
@@ -794,11 +559,11 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    # Extra barrier is added to make sure
    # all ranks report the max time.
    torch.distributed.barrier()
    timers('save checkpoint').start()
    timers('save-checkpoint').start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    timers('save checkpoint').stop()
    timers.log(['save checkpoint'])
    timers('save-checkpoint').stop()
    timers.log(['save-checkpoint'])


def train(forward_step_func, model, optimizer, lr_scheduler,
...
...
@@ -811,7 +576,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    model.train()
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}
...
...
@@ -819,7 +585,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    # Iterations.
    iteration = args.iteration

    timers('interval time').start()
    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
...
...
@@ -900,7 +666,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    model.eval()
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}
...
...
@@ -912,37 +679,30 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))
            for _ in range(get_num_microbatches()):
                if not mpu.is_pipeline_first_stage():
                    input_tensor, _ = communicate(
                        tensor_send_next=None,
                        tensor_send_prev=None,
                        recv_forward=True,
                        recv_backward=False)
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                if args.virtual_pipeline_model_parallel_size is not None:
                    forward_backward_func = forward_backward_pipelining_with_interleaving
                else:
                    input_tensor = None

                # Forward evaluation.
                output_tensor = forward_step_func(data_iterator, model, input_tensor)

                if mpu.is_pipeline_last_stage():
                    _, loss_dict = output_tensor
                    # Reduce across processes.
                    forward_backward_func = forward_backward_pipelining_without_interleaving
            else:
                forward_backward_func = forward_backward_no_pipelining
            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + \
                            loss_dict[key]
                else:
                    communicate(
                        tensor_send_next=output_tensor,
                        tensor_send_prev=None,
                        recv_forward=False,
                        recv_backward=False)
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()
    # Move model back to the train mode.
    model.train()
    for model_module in model:
        model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
...
...
megatron/utils.py
View file @
b2f57fc4
...
...
@@ -18,6 +18,7 @@
import sys

import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
...
...
@@ -26,11 +27,25 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.checkpointing import save_checkpoint
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


def unwrap_model(model, module_instances=(torchDDP)):
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model
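A short usage sketch of unwrap_model() as defined above: wrapper layers are peeled off whether the input is a single module or the list of virtual-pipeline model chunks, and the input's shape (module vs. list) is preserved. The Wrapper class here is a stand-in for torchDDP or FP16Module, not a Megatron class.

# Stand-in wrapper for illustration; any class exposing .module works.
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

inner = nn.Linear(4, 4)
assert unwrap_model(Wrapper(inner), (Wrapper,)) is inner       # single module
assert unwrap_model([Wrapper(inner)], (Wrapper,)) == [inner]   # list of chunks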
def calc_params_l2_norm(model):
    """Calculate l2 norm of parameters """
    # Remove duplicate params.
...
...
@@ -106,6 +121,8 @@ def print_params_min_max_norm(optimizer, iteration):
def check_adlr_autoresume_termination(iteration, model, optimizer,
                                      lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
    from megatron.checkpointing import save_checkpoint

    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
...
...
pretrain_bert.py
View file @
b2f57fc4
...
...
@@ -38,7 +38,7 @@ def model_provider():
    args = get_args()
    num_tokentypes = 2 if args.bert_binary_head else 0
    if mpu.get_pipeline_model_parallel_world_size() > 1:
    def model_provider_pipelined():
        # Determine model based on position of stage in pipeline.
        if mpu.is_pipeline_first_stage():
            model = BertModelFirstStage(
...
...
@@ -51,6 +51,17 @@ def model_provider():
        else:
            model = BertModelIntermediateStage(
                num_tokentypes=num_tokentypes)
        return model

    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            model = []
            for i in range(args.virtual_pipeline_model_parallel_size):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model.append(model_provider_pipelined())
        else:
            model = model_provider_pipelined()
    else:
        model = BertModel(
            num_tokentypes=num_tokentypes,
...
...
@@ -92,8 +103,8 @@ def forward_step(data_iterator, model, input_tensor):
    # Get the batch.
    timers('batch-generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
        = get_batch(data_iterator)
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    if not args.bert_binary_head:
...
...
pretrain_gpt.py
View file @
b2f57fc4
...
...
@@ -35,8 +35,8 @@ def model_provider():
"""Build the model."""
print_rank_0
(
'building GPT model ...'
)
args
=
get_args
()
i
f
m
pu
.
get_pipeline_model_parallel_world_size
()
>
1
:
de
f
m
odel_provider_pipelined
()
:
# Determine model based on position of stage in pipeline.
if
mpu
.
is_pipeline_first_stage
():
model
=
GPTModelFirstStage
(
num_tokentypes
=
0
)
...
...
@@ -46,6 +46,17 @@ def model_provider():
        else:
            model = GPTModelIntermediateStage(
                num_tokentypes=0)
        return model

    args = get_args()
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            model = []
            for i in range(args.virtual_pipeline_model_parallel_size):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                model.append(model_provider_pipelined())
        else:
            model = model_provider_pipelined()
    else:
        model = GPTModel(num_tokentypes=0, parallel_output=True)
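An illustrative sketch (stand-ins only, not code from this diff) of what the interleaved branch above produces: each chunk is built while a different virtual pipeline rank is active, so a rank ends up holding a list of smaller model chunks instead of a single module.

# Hedged sketch with stand-ins for the mpu setter and the model provider.
current_virtual_rank = None

def set_virtual_pipeline_model_parallel_rank(rank):
    # Stand-in for mpu.set_virtual_pipeline_model_parallel_rank().
    global current_virtual_rank
    current_virtual_rank = rank

def model_provider_stub():
    # A real provider would build the transformer block owned by this virtual rank.
    return {'virtual_rank': current_virtual_rank}

virtual_size = 2  # assumed value of args.virtual_pipeline_model_parallel_size
model = []
for i in range(virtual_size):
    set_virtual_pipeline_model_parallel_rank(i)
    model.append(model_provider_stub())
print([chunk['virtual_rank'] for chunk in model])  # [0, 1]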
...
...
tools/generate_samples_gpt.py
View file @
b2f57fc4
...
...
@@ -92,7 +92,9 @@ def main():
"""Main program."""
initialize_megatron
(
extra_args_provider
=
add_text_generate_args
,
args_defaults
=
{
'tokenizer_type'
:
'GPT2BPETokenizer'
})
args_defaults
=
{
'tokenizer_type'
:
'GPT2BPETokenizer'
,
'no_load_rng'
:
True
,
'no_load_optim'
:
True
})
# Set up model and load checkpoint.
model
=
get_model
(
model_provider
)
...
...
tools/merge_mp_partitions.py
View file @
b2f57fc4
...
...
@@ -200,6 +200,8 @@ def main():
                                       'micro_batch_size': 1,
                                       'no_load_optim': True,
                                       'no_load_rng': True,
                                       'no_save_optim': True,
                                       'no_save_rng': True,
                                       'save_interval': 1})
    args = get_args()
...
...
@@ -240,6 +242,11 @@ def main():
    tokenizer = rebuild_tokenizer(args)
    mpu.initialize.set_tensor_model_parallel_world_size(
        args.tensor_model_parallel_size)
    for rank in range(args.tensor_model_parallel_size):
        # Reset these since load_checkpoint asserts they are 0, but we are loading
        # multiple checkpoints in the same process and they get set each time
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0

        mpu.initialize.set_tensor_model_parallel_rank(rank)
        checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
        model_ = get_model(model_type)
...
...