OpenDAS / Megatron-LM · Commits

Commit ee7b19e7 — authored Apr 02, 2021 by Mostofa Patwary
Merge branch 'main' into main_dedup
Parents: d413bd5f, f2d64c00
Changes: 49 files in the merge; showing 20 changed files with 1916 additions and 780 deletions (+1916 −780).
Changed files shown below:

  megatron/indexer.py                  +79   −43
  megatron/initialize.py               +73   −13
  megatron/model/__init__.py           +2    −18
  megatron/model/bert_model.py         +1    −2
  megatron/model/biencoder_model.py    +295  −0
  megatron/model/distributed.py        +178  −72
  megatron/model/fused_layer_norm.py   +28   −117
  megatron/model/fused_softmax.py      +13   −5
  megatron/model/module.py             +28   −14
  megatron/model/transformer.py        +28   −5
  megatron/mpu/__init__.py             +3    −0
  megatron/mpu/initialize.py           +46   −3
  megatron/optimizer/__init__.py       +49   −29
  megatron/optimizer/clip_grads.py     +28   −0
  megatron/optimizer/optimizer.py      +189  −83
  megatron/p2p_communication.py        +264  −0
  megatron/schedules.py                +435  −0
  megatron/training.py                 +133  −367
  megatron/utils.py                    +30   −6
  pretrain_bert.py                     +14   −3
megatron/indexer.py  (+79 −43)

The indexer is ported from the ICT/REALM code path to the new biencoder code path.

Removed imports:
  from megatron.checkpointing import load_ict_checkpoint
  from megatron.data.ict_dataset import get_ict_dataset
  from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
  from megatron.data.realm_dataset_utils import get_ict_batch
  from megatron.data.realm_index import detach, BlockData
  from megatron.model.realm_model import general_ict_model_provider

Added imports:
  import sys
  from megatron.checkpointing import load_biencoder_checkpoint
  from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
  from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
  from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
  from megatron.data.realm_index import detach, OpenRetreivalDataStore
  from megatron.model.biencoder_model import biencoder_model_provider

The imports of torch, torch.distributed, get_args, mpu and get_model are unchanged.

IndexBuilder.__init__ now tracks an evidence_embedder_obj instead of block_data,
records args.biencoder_shared_query_context_model, and comments out the old
using_realm_chkpt flag:

    self.model = None
    self.dataloader = None
    self.evidence_embedder_obj = None
    self.biencoder_shared_query_context_model = \
        args.biencoder_shared_query_context_model

    # need to know whether we're using a REALM checkpoint (args.load)
    # or ICT checkpoint
    assert not (args.load and args.ict_load)
    # self.using_realm_chkpt = args.ict_load is None

    self.log_interval = args.indexer_log_interval
    self.batch_size = args.indexer_batch_size

@@ -33,59 +40,88 @@ class IndexBuilder(object):

load_attributes() builds a biencoder context model instead of the old ICT block
model (which used general_ict_model_provider, load_ict_checkpoint, get_ict_dataset
and BlockData):

    def load_attributes(self):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        model = get_model(lambda: biencoder_model_provider(
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=
                self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(
            model, only_context_model=only_context_model)

        assert len(self.model) == 1
        self.model[0].eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset,
                                                        self.batch_size))

        self.evidence_embedder_obj = OpenRetreivalDataStore(
            load_from_path=False)

track_and_report_progress() is unchanged apart from a reflowed docstring.
build_and_save_index() now embeds open-retrieval batches with embed_text() on the
context model rather than embed_block() into BlockData:

    def build_and_save_index(self):
        """
        Goes through one epoch of the dataloader and adds all data to this
        instance's BlockData.

        The copy of BlockData is saved as a shard, which when run in a
        distributed setting will be consolidated by the rank 0 process
        and saved as a final pickled BlockData.
        """
        assert len(self.model) == 1
        unwrapped_model = self.model[0]
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        while True:
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
                    context_pad_mask = get_open_retrieval_batch(
                        self.dataloader)
            except (StopIteration, IndexError):
                break

            # TODO: can we add with torch.no_grad() to reduce memory usage
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
            context_logits = unwrapped_model.embed_text(
                unwrapped_model.context_model, context_tokens, context_mask,
                context_types)
            context_logits = detach(context_logits)
            row_id = detach(row_id)

            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))

        # This process signals to finalize its shard and then synchronize with
        # the other processes
        self.evidence_embedder_obj.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.evidence_embedder_obj.merge_shards_and_save()
            # make sure that every single piece of data was embedded
            assert len(self.evidence_embedder_obj.embed_data) == \
                len(self.dataset)
        self.evidence_embedder_obj.clear()

        # complete building the final copy
        torch.distributed.barrier()
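For orientation, the shard-then-merge pattern the indexer relies on can be sketched
without any Megatron machinery. The ShardStore below is a hypothetical stand-in, not
the OpenRetreivalDataStore API; the real flow additionally synchronizes ranks with
torch.distributed.barrier() before the main builder merges.

    import glob
    import os
    import pickle

    class ShardStore:
        """Each worker writes one pickled shard; the main builder merges them."""
        def __init__(self, shard_dir, rank):
            self.shard_dir, self.rank, self.embed_data = shard_dir, rank, {}

        def add_block_data(self, row_ids, embeddings):
            for row_id, emb in zip(row_ids, embeddings):
                self.embed_data[int(row_id)] = emb

        def save_shard(self):
            os.makedirs(self.shard_dir, exist_ok=True)
            path = os.path.join(self.shard_dir, 'shard_{}.pkl'.format(self.rank))
            with open(path, 'wb') as f:
                pickle.dump(self.embed_data, f)

        def merge_shards_and_save(self, output_path):
            merged = {}
            for path in glob.glob(os.path.join(self.shard_dir, 'shard_*.pkl')):
                with open(path, 'rb') as f:
                    merged.update(pickle.load(f))
            with open(output_path, 'wb') as f:
                pickle.dump(merged, f)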
megatron/initialize.py  (+73 −13)

@@ -17,16 +17,20 @@
Imports gain `import time` and `from megatron import fused_kernels`, and the mpu
import is wrapped in parentheses:

    from megatron.mpu import (set_tensor_model_parallel_rank,
                              set_tensor_model_parallel_world_size)

The signature initialize_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False, allow_no_cuda=False) is unchanged, as are the docstring
note about lazy_mpu_init and the CUDA availability assertion.

@@ -66,7 +69,8 @@
In the lazy-initialization branch only a comment is reflowed:

    set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
    # and return function for external DDP manager
    # to call when it has DDP initialized
    set_tensor_model_parallel_rank(args.rank)
    return finish_mpu_init

@@ -79,16 +83,71 @@
The inline rank-0 compilation of the dataset C++ helpers is replaced by a call to a
new _compile_dependencies() helper:

    # Autoresume.
    _init_autoresume()

    # Compile dependencies.
    _compile_dependencies()

    # No continuation function
    return None


    def _compile_dependencies():

        args = get_args()

        # =========================
        # Compile dataset C++ code.
        # =========================
        # TODO: move this to ninja
        if torch.distributed.get_rank() == 0:
            start_time = time.time()
            print('> compiling dataset index builder ...')
            from megatron.data.dataset_utils import compile_helper
            compile_helper()
            print('>>> done with dataset index builder. Compilation time: '
                  '{:.3f} seconds'.format(time.time() - start_time), flush=True)

        # ==================
        # Load fused kernels
        # ==================

        # Custom kernel constraints check.
        seq_len = args.seq_length
        attn_batch_size = \
            (args.num_attention_heads / args.tensor_model_parallel_size) * \
            args.micro_batch_size
        # Constraints on sequence length and attn_batch_size to enable warp based
        # optimization and upper triangular optimization (for causal mask)
        custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
            seq_len % 4 == 0 and attn_batch_size % 4 == 0
        # Print a warning.
        if not ((args.fp16 or args.bf16) and
                custom_kernel_constraint and
                args.masked_softmax_fusion):
            if args.rank == 0:
                print('WARNING: constraints for invoking optimized'
                      ' fused softmax kernel are not met. We default'
                      ' back to unfused kernel invocations.', flush=True)

        # Always build on rank zero first.
        if torch.distributed.get_rank() == 0:
            start_time = time.time()
            print('> compiling and loading fused kernels ...', flush=True)
            fused_kernels.load(args)
            torch.distributed.barrier()
        else:
            torch.distributed.barrier()
            fused_kernels.load(args)
        # Simple barrier to make sure all ranks have passed the
        # compilation phase successfully before moving on to the
        # rest of the program. We think this might ensure that
        # the lock is released.
        torch.distributed.barrier()
        if torch.distributed.get_rank() == 0:
            print('>>> done with compiling and loading fused kernels. '
                  'Compilation time: {:.3f} seconds'.format(
                      time.time() - start_time), flush=True)

@@ -133,7 +192,8 @@ def _initialize_distributed():
mpu.initialize_model_parallel() now also receives the virtual pipeline size:

        mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                      args.pipeline_model_parallel_size,
                                      args.virtual_pipeline_model_parallel_size)
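The constraint check inside _compile_dependencies() can be read as a small standalone
predicate. The sketch below mirrors it; the function name and argument list are ad hoc,
not Megatron APIs.

    def fused_softmax_constraints_met(seq_length, num_attention_heads,
                                      tensor_model_parallel_size,
                                      micro_batch_size, fp16_or_bf16,
                                      masked_softmax_fusion):
        # Warp-based and upper-triangular (causal) optimizations only apply to
        # moderate sequence lengths whose attention batch divides evenly by 4.
        attn_batch_size = (num_attention_heads / tensor_model_parallel_size) \
            * micro_batch_size
        custom_kernel_constraint = 16 < seq_length <= 2048 and \
            seq_length % 4 == 0 and attn_batch_size % 4 == 0
        return fp16_or_bf16 and custom_kernel_constraint and masked_softmax_fusion

    # Example: a 1024-token fp16 model with 16 heads, TP=1, micro batch 4.
    print(fused_softmax_constraints_met(1024, 16, 1, 4, True, True))   # True
    print(fused_softmax_constraints_met(3000, 16, 1, 4, True, True))   # False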
megatron/model/__init__.py  (+2 −18)

@@ -13,34 +13,18 @@
The lazy import_layernorm(fp32_residual_connection) helper (which cached either
MixedFusedLayerNorm or apex's FusedLayerNorm in a module-level _LAYER_NORM global)
is removed. The package now exports a single LayerNorm directly, and FP16Module is
renamed Float16Module:

    from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm

    from .distributed import *
    from .bert_model import (BertModel,
                             BertModelFirstStage,
                             BertModelIntermediateStage,
                             BertModelLastStage)
    from .gpt_model import (GPTModel,
                            GPTModelFirstStage,
                            GPTModelIntermediateStage,
                            GPTModelLastStage)
    from .language_model import get_language_model
    from .module import Float16Module
    from .realm_model import ICTBertModel
megatron/model/bert_model.py  (+1 −2)

@@ -22,7 +22,7 @@
    from megatron.model.language_model import get_language_model
   -from megatron.model import import_layernorm
   +from megatron.model import LayerNorm
    from megatron.model.utils import openai_gelu, erf_gelu

@@ -78,7 +78,6 @@ class BertLMHead(MegatronModule):
BertLMHead no longer calls import_layernorm(args.fp32_residual_connection); it uses
the imported LayerNorm directly:

    self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
    self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
    self.gelu = torch.nn.functional.gelu
megatron/model/biencoder_model.py  (new file, mode 100644, +295 −0)

The new module defines biencoder_model_provider(), BiEncoderModel and
PretrainedBertModel. Reconstructed from the diff (checkpoint and state-dict
plumbing summarized in prose between the blocks):

    import os
    import sys
    import torch

    from megatron import get_args, print_rank_0
    from megatron import mpu, get_tokenizer
    from megatron.checkpointing import fix_query_key_value_ordering
    from megatron.checkpointing import get_checkpoint_tracker_filename
    from megatron.checkpointing import get_checkpoint_name
    from megatron.model.bert_model import bert_position_ids
    from megatron.model.enums import AttnMaskType
    from megatron.model.language_model import get_language_model
    from megatron.model.utils import get_linear_layer
    from megatron.model.utils import init_method_normal
    from megatron.model.utils import scaled_init_method_normal
    from .module import MegatronModule


    def biencoder_model_provider(only_query_model=False,
                                 only_context_model=False,
                                 biencoder_shared_query_context_model=False):
        """Build the model."""
        args = get_args()

        assert mpu.get_tensor_model_parallel_world_size() == 1 and \
            mpu.get_pipeline_model_parallel_world_size() == 1, \
            "Model parallel size > 1 not supported for ICT"

        print_rank_0('building BiEncoderModel...')

        # simpler to just keep using 2 tokentypes since
        # the LM we initialize with has 2 tokentypes
        model = BiEncoderModel(
            num_tokentypes=2,
            parallel_output=False,
            only_query_model=only_query_model,
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=
                biencoder_shared_query_context_model)
        return model


    class BiEncoderModel(MegatronModule):
        """Bert-based module for Biencoder model."""

        def __init__(self, num_tokentypes=1, parallel_output=True,
                     only_query_model=False, only_context_model=False,
                     biencoder_shared_query_context_model=False):
            super(BiEncoderModel, self).__init__()
            args = get_args()

            bert_kwargs = dict(num_tokentypes=num_tokentypes,
                               parallel_output=parallel_output)

            self.biencoder_shared_query_context_model = \
                biencoder_shared_query_context_model
            assert not (only_context_model and only_query_model)
            self.use_context_model = not only_query_model
            self.use_query_model = not only_context_model
            self.biencoder_projection_dim = args.biencoder_projection_dim

            if self.biencoder_shared_query_context_model:
                self.model = PretrainedBertModel(**bert_kwargs)
                self._model_key = 'shared_model'
                self.query_model, self.context_model = self.model, self.model
            else:
                if self.use_query_model:
                    # this model embeds (pseudo-)queries - Embed_input in the paper
                    self.query_model = PretrainedBertModel(**bert_kwargs)
                    self._query_key = 'query_model'
                if self.use_context_model:
                    # this model embeds evidence blocks - Embed_doc in the paper
                    self.context_model = PretrainedBertModel(**bert_kwargs)
                    self._context_key = 'context_model'

        def forward(self, query_tokens, query_attention_mask, query_types,
                    context_tokens, context_attention_mask, context_types):
            """Run a forward pass for each of the models and
            return the respective embeddings."""
            if self.use_query_model:
                query_logits = self.embed_text(self.query_model, query_tokens,
                                               query_attention_mask, query_types)
            else:
                raise ValueError("Cannot embed query without the query model.")
            if self.use_context_model:
                context_logits = self.embed_text(self.context_model,
                                                 context_tokens,
                                                 context_attention_mask,
                                                 context_types)
            else:
                raise ValueError("Cannot embed block without the block model.")
            return query_logits, context_logits

        @staticmethod
        def embed_text(model, tokens, attention_mask, token_types):
            """Embed a batch of tokens using the model"""
            return model(tokens, attention_mask, token_types)

BiEncoderModel also implements state_dict_for_save_checkpoint() and
load_state_dict(), which save/load either the shared model under 'shared_model' or
the query/context models under 'query_model' and 'context_model', and
init_state_dict_from_bert(), which reads the checkpoint tracker file under
args.bert_load, loads the pretrained BERT language-model weights into each encoder
(aliasing the deprecated fp16 loss_scaler module paths for old checkpoints), copies
the query projection head into the context model when biencoder_projection_dim > 0,
and calls fix_query_key_value_ordering() for the stored checkpoint_version.

    class PretrainedBertModel(MegatronModule):
        """BERT-based encoder for queries or contexts used for
        learned information retrieval."""

        def __init__(self, num_tokentypes=2, parallel_output=True):
            super(PretrainedBertModel, self).__init__()

            args = get_args()
            tokenizer = get_tokenizer()
            self.pad_id = tokenizer.pad
            self.biencoder_projection_dim = args.biencoder_projection_dim
            self.parallel_output = parallel_output
            init_method = init_method_normal(args.init_method_std)
            scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers)

            self.language_model, self._language_model_key = get_language_model(
                num_tokentypes=num_tokentypes,
                add_pooler=False,
                encoder_attn_mask_type=AttnMaskType.padding,
                init_method=init_method,
                scaled_init_method=scaled_init_method)

            if args.biencoder_projection_dim > 0:
                self.projection_enc = get_linear_layer(
                    args.hidden_size, args.biencoder_projection_dim, init_method)
                self._projection_enc_key = 'projection_enc'

        def forward(self, input_ids, attention_mask, tokentype_ids=None):
            extended_attention_mask = attention_mask.unsqueeze(1)
            #extended_attention_mask = bert_extended_attention_mask(attention_mask)
            position_ids = bert_position_ids(input_ids)

            lm_output = self.language_model(input_ids, position_ids,
                                            extended_attention_mask,
                                            tokentype_ids=tokentype_ids)
            # This mask will be used in average-pooling and max-pooling
            pool_mask = (input_ids == self.pad_id).unsqueeze(2)

            # Taking the representation of the [CLS] token of BERT
            pooled_output = lm_output[:, 0, :]

            # Converting to float16 dtype
            pooled_output = pooled_output.to(lm_output.dtype)

            # Output.
            if self.biencoder_projection_dim:
                pooled_output = self.projection_enc(pooled_output)
            return pooled_output

PretrainedBertModel's state_dict_for_save_checkpoint() stores the language model
under its own key (plus 'projection_enc' when a projection head is used), and
load_state_dict() restores both, logging via print_rank_0.
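A bi-encoder like this is typically used downstream by comparing query embeddings
against context embeddings with a dot product. A minimal sketch of that scoring step
(not part of this file; shapes and names are assumptions):

    import torch

    def retrieval_scores(query_logits, context_logits):
        # query_logits: [num_queries, dim], context_logits: [num_contexts, dim]
        # Higher score means the context is a better match for the query.
        return torch.matmul(query_logits, context_logits.t())

    queries = torch.randn(2, 128)
    contexts = torch.randn(5, 128)
    scores = retrieval_scores(queries, contexts)   # [2, 5]
    best = torch.argmax(scores, dim=1)             # best context per query
    print(scores.shape, best)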
megatron/model/distributed.py  (+178 −72)

@@ -13,100 +13,206 @@
The old hook-based DistributedDataParallel (gloo half-precision warning,
allreduce_params() with a bucketized all-reduce queued through
Variable._execution_engine, and the commented-out _sync_buffers/train work-arounds)
is removed, along with the torch.distributed/Module/Variable/get_args imports it
needed. The file now imports `from abc import ABC, abstractmethod` and defines a
MemoryBuffer helper, an abstract DistributedDataParallelBase, and a new
DistributedDataParallel with optional contiguous gradient buffers:

    class MemoryBuffer:

        def __init__(self, numel, dtype):
            self.numel = numel
            self.dtype = dtype
            self.data = torch.zeros(self.numel,
                                    dtype=self.dtype,
                                    device=torch.cuda.current_device(),
                                    requires_grad=False)

        def zero(self):
            """Reset the buffer to zero."""
            self.data.zero_()

        def get(self, shape, start_index):
            """Return a tensor with the input `shape` as a view into the
            1-D data starting at `start_index`."""
            end_index = start_index + shape.numel()
            assert end_index <= self.numel, \
                'requested tensor is out of the buffer range.'
            return self.data[start_index:end_index].view(shape)


    class DistributedDataParallelBase(MegatronModule, ABC):
        """Abstract class for DDP."""

        def __init__(self, module):
            super(DistributedDataParallelBase, self).__init__()
            # Keep a pointer to the model.
            self.module = module

        @abstractmethod
        def allreduce_gradients(self):
            pass

        def forward(self, *inputs, **kwargs):
            return self.module(*inputs, **kwargs)

        # state_dict, state_dict_for_save_checkpoint and load_state_dict
        # delegate directly to self.module.


    class DistributedDataParallel(DistributedDataParallelBase):
        """DDP with contiguous buffers options to storre and accumulate gradients.
        This class:
            - has the potential to reduce memory fragmentation.
            - provides the option to do the gradient accumulation
              in a type other than the params type (for example fp32)
        Arguments:
            module: input model.
            accumulate_allreduce_grads_in_fp32: if true do the gradient
                accumulation and the gradient all-reduce all in float32. If this
                option is true, we require `use_contiguous_buffers` to be true too.
            use_contiguous_buffers: if true, use a contiguous buffer to store
                the gradients.
        """

        def __init__(self, module, accumulate_allreduce_grads_in_fp32,
                     use_contiguous_buffers):
            super(DistributedDataParallel, self).__init__(module)
            self.accumulate_allreduce_grads_in_fp32 \
                = accumulate_allreduce_grads_in_fp32
            self.use_contiguous_buffers = use_contiguous_buffers
            # If we are using fp32-accumulate-allreduce explicitly
            # this means we need main grads in a continous buffer.
            if self.accumulate_allreduce_grads_in_fp32:
                assert self.use_contiguous_buffers

            # The rest applies only when continuous buffers are used.
            self._grad_buffers = None
            if self.use_contiguous_buffers:
                self._grad_buffers = {}

                # Simple function to define buffer type.
                def _get_buffer_type(param):
                    return torch.float if \
                        self.accumulate_allreduce_grads_in_fp32 else param.dtype

                # First calculate total number of elements per type,
                # then allocate one MemoryBuffer per dtype.
                type_num_elements = {}
                for param in self.module.parameters():
                    if param.requires_grad:
                        dtype = _get_buffer_type(param)
                        type_num_elements[dtype] = \
                            type_num_elements.get(dtype, 0) + \
                            param.data.nelement()
                for dtype, num_elements in type_num_elements.items():
                    self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)

                # Assume the back prop order is reverse the params order,
                # store the start index for the gradients.
                for param in self.module.parameters():
                    if param.requires_grad:
                        dtype = _get_buffer_type(param)
                        type_num_elements[dtype] -= param.data.nelement()
                        param.main_grad = self._grad_buffers[dtype].get(
                            param.data.shape, type_num_elements[dtype])

                # Backward hook: register on each parameter's gradient
                # accumulator; keep references so they don't go out of scope.
                self.grad_accs = []
                for param in self.module.parameters():
                    if param.requires_grad:
                        # Expand so we get access to grad_fn.
                        param_tmp = param.expand_as(param)
                        grad_acc = param_tmp.grad_fn.next_functions[0][0]
                        grad_acc.register_hook(self._make_param_hook(param))
                        self.grad_accs.append(grad_acc)

        def _make_param_hook(self, param):
            """Create the all-reduce hook for backprop."""
            def param_hook(*unused):
                # Add the gradient to the buffer.
                if param.grad.data is not None:
                    param.main_grad.add_(param.grad.data)
                    # Now we can deallocate grad memory.
                    param.grad = None
            return param_hook

        def zero_grad_buffer(self):
            """Set the grad buffer data to zero. Needs to be called at the
            begining of each iteration."""
            assert self._grad_buffers is not None, 'buffers are not initialized.'
            for _, buffer_ in self._grad_buffers.items():
                buffer_.zero()

        def allreduce_gradients(self):
            """Reduce gradients across data parallel ranks."""
            # If we have buffers, simply reduce the data in the buffer.
            if self._grad_buffers is not None:
                for _, buffer_ in self._grad_buffers.items():
                    buffer_.data /= mpu.get_data_parallel_world_size()
                    torch.distributed.all_reduce(
                        buffer_.data, group=mpu.get_data_parallel_group())
            else:
                # Otherwise, bucketize and all-reduce
                buckets = {}
                # Pack the buckets.
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = param.data.type()
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)
                        param.main_grad = param.grad

                # For each bucket, all-reduce and copy all-reduced grads.
                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    coalesced /= mpu.get_data_parallel_world_size()
                    torch.distributed.all_reduce(
                        coalesced, group=mpu.get_data_parallel_group())
                    for buf, synced in zip(
                            grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)
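The core idea of the contiguous-buffer DDP can be seen in isolation: allocate one
flat tensor per dtype and hand each parameter a main_grad view into it, so a single
all-reduce covers every gradient. A CPU-only sketch of the view mechanics (no
backward hooks, no all-reduce; the names are ad hoc):

    import torch

    params = [torch.nn.Parameter(torch.randn(4, 3)),
              torch.nn.Parameter(torch.randn(10))]
    numel = sum(p.data.nelement() for p in params)
    flat_buffer = torch.zeros(numel)

    # Hand out views back-to-front, mirroring the reverse back-prop order
    # assumption made by the real class.
    offset = numel
    main_grads = []
    for p in params:
        offset -= p.data.nelement()
        main_grads.append(
            flat_buffer[offset:offset + p.data.nelement()].view(p.data.shape))

    # Accumulating into a view updates the flat buffer in place, so one
    # all-reduce over `flat_buffer` would cover every gradient at once.
    main_grads[0].add_(torch.ones_like(params[0]))
    print(flat_buffer[-12:])   # the twelve elements backing params[0]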
megatron/model/fused_layer_norm.py  (+28 −117)

@@ -15,29 +15,23 @@
The header docstring now says the code is copied from NVIDIA apex "with some
changes" (previously "with minor changes"). Imports drop `math` and the
module-level fused_layer_norm_cuda global, and add
`from torch.nn import functional as F`; only the fused_mix_prec_layer_norm_cuda
global remains. FusedLayerNormAffineFunction.forward() now imports the kernel
lazily:

    global fused_mix_prec_layer_norm_cuda
    if fused_mix_prec_layer_norm_cuda is None:
        fused_mix_prec_layer_norm_cuda = importlib.import_module(
            "fused_mix_prec_layer_norm_cuda")

@@ -46,134 +40,51 @@
The backward pass is unchanged. The separate FusedLayerNormFunction autograd class
and the fused_layer_norm_affine()/fused_layer_norm() wrappers are deleted, as are
the long apex docstring and the elementwise_affine option of MixedFusedLayerNorm.
The class reduces to:

    class MixedFusedLayerNorm(torch.nn.Module):

        def __init__(self, normalized_shape, eps=1e-5):
            super(MixedFusedLayerNorm, self).__init__()

            global fused_mix_prec_layer_norm_cuda
            fused_mix_prec_layer_norm_cuda = importlib.import_module(
                "fused_mix_prec_layer_norm_cuda")

            if isinstance(normalized_shape, numbers.Integral):
                normalized_shape = (normalized_shape,)
            self.normalized_shape = torch.Size(normalized_shape)
            self.eps = eps
            self.weight = Parameter(torch.Tensor(*normalized_shape))
            self.bias = Parameter(torch.Tensor(*normalized_shape))
            self.reset_parameters()

        def reset_parameters(self):
            init.ones_(self.weight)
            init.zeros_(self.bias)

        def forward(self, input):
            if not input.is_cuda:
                return F.layer_norm(input, self.normalized_shape, self.weight,
                                    self.bias, self.eps)
            return FusedLayerNormAffineFunction.apply(
                input, self.weight, self.bias, self.normalized_shape, self.eps)
megatron/model/fused_softmax.py  (+13 −5)

@@ -96,6 +96,7 @@ and @@ -104,6 +105,10 @@ class FusedScaleMaskSoftmax:
__init__ gains an input_in_bf16 argument and a combined float16 flag:

    self.input_in_fp16 = input_in_fp16
    self.input_in_bf16 = input_in_bf16
    assert not (self.input_in_fp16 and self.input_in_bf16), \
        'both fp16 and bf16 flags cannot be active at the same time.'
    self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
    self.attn_mask_type = attn_mask_type
    self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
    self.mask_func = mask_func

@@ -128,8 +133,8 @@ and @@ -142,7 +147,7 @@
The custom-kernel and fp32-upcast branches now test input_in_float16 instead of
input_in_fp16.

@@ -150,7 +155,10 @@
When the softmax is computed in fp32, the result is cast back to the input type:

    if self.input_in_float16 and self.softmax_in_fp32:
        if self.input_in_fp16:
            probs = probs.half()
        else:
            probs = probs.bfloat16()

    return probs
megatron/model/module.py  (+28 −14)

@@ -25,6 +25,7 @@
A bf16 type tuple is added next to the existing float and half tuples:

    _BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)

@@ -50,9 +51,9 @@ class MegatronModule:
word_embeddings_weight() now ignores virtual pipeline stages:

    if mpu.is_pipeline_first_stage(ignore_virtual=True):
        return self.language_model.embedding.word_embeddings.weight
    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        if not self.share_word_embeddings:
            raise Exception('word_embeddings_weight() called for last '
                            'stage, but share_word_embeddings is false')

@@ -120,44 +122,56 @@
The conversion helpers are generalized to fp16/bf16: fp32_to_fp16(val) becomes
fp32_to_float16(val, float16_convertor) and calls the convertor instead of
val.half(); fp16_to_fp32(val) becomes float16_to_fp32(val) and accepts both
_BF16_TYPES and _HALF_TYPES. FP16Module is renamed Float16Module and picks the
convertor from the args:

    class Float16Module(MegatronModule):

        def __init__(self, module, args):
            super(Float16Module, self).__init__()

            if args.fp16:
                self.add_module('module', module.half())
                def float16_convertor(val):
                    return val.half()
            elif args.bf16:
                self.add_module('module', module.bfloat16())
                def float16_convertor(val):
                    return val.bfloat16()
            else:
                raise Exception('should not be here')

            self.float16_convertor = float16_convertor

        def forward(self, *inputs, **kwargs):
            if mpu.is_pipeline_first_stage():
                inputs = fp32_to_float16(inputs, self.float16_convertor)
            outputs = self.module(*inputs, **kwargs)
            if mpu.is_pipeline_last_stage():
                outputs = float16_to_fp32(outputs)
            return outputs
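The fp32-to-float16 conversion is just a recursive map over nested inputs with a
pluggable convertor. A simplified sketch (plain tensors only, without Megatron's
Parameter/Variable special-casing):

    import torch

    def conversion_helper(val, conversion):
        """Apply `conversion` to val, recursing into tuples and lists."""
        if isinstance(val, (tuple, list)):
            return type(val)(conversion_helper(v, conversion) for v in val)
        return conversion(val)

    def fp32_to_float16(val, float16_convertor):
        def maybe_convert(v):
            if torch.is_tensor(v) and v.is_floating_point():
                return float16_convertor(v)
            return v
        return conversion_helper(val, maybe_convert)

    inputs = (torch.randn(2, 2), [torch.randn(3)], 7)
    bf16_inputs = fp32_to_float16(inputs, lambda t: t.bfloat16())
    print(bf16_inputs[0].dtype, bf16_inputs[1][0].dtype, bf16_inputs[2])
    # torch.bfloat16 torch.bfloat16 7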
megatron/model/transformer.py View file @ ee7b19e7
...
@@ -22,7 +22,7 @@ from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, LayerType, AttnType
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
...
@@ -116,6 +116,7 @@ class ParallelAttention(MegatronModule):
         super(ParallelAttention, self).__init__()
         args = get_args()
         self.fp16 = args.fp16
+        self.bf16 = args.bf16

         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
         self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
...
@@ -164,7 +165,7 @@ class ParallelAttention(MegatronModule):
             self.norm_factor *= coeff

         self.scale_mask_softmax = FusedScaleMaskSoftmax(
-            self.fp16,
+            self.fp16, self.bf16,
             self.attn_mask_type,
             args.masked_softmax_fusion,
             attention_mask_func,
...
@@ -397,8 +398,10 @@ class ParallelTransformerLayer(MegatronModule):
         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm

+        self.bf16 = args.bf16
+        self.fp32_residual_connection = args.fp32_residual_connection
+
         # Layernorm on the input data.
-        LayerNorm = import_layernorm(args.fp32_residual_connection)
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)
...
@@ -533,6 +536,7 @@ class ParallelTransformer(MegatronModule):
         super(ParallelTransformer, self).__init__()
         args = get_args()

+        self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection

         # Store activation checkpoiting flag.
...
@@ -552,13 +556,32 @@ class ParallelTransformer(MegatronModule):
                 layer_number,
                 layer_type=layer_type,
                 self_attn_mask_type=self_attn_mask_type)
-        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+        if args.virtual_pipeline_model_parallel_size is not None:
+            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
+                'num_layers_per_stage must be divisible by ' \
+                'virtual_pipeline_model_parallel_size'
+            # Number of layers in each model chunk is the number of layers in the stage,
+            # divided by the number of model chunks in a stage.
+            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
+            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0]  [2]  [4]  [6]
+            # Stage 1: [1]  [3]  [5]  [7]
+            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
+            # layers to stages like (each list is a model chunk):
+            # Stage 0: [0, 1]  [4, 5]
+            # Stage 1: [2, 3]  [6, 7]
+            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
+                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
+                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
+        else:
+            # Each stage gets a contiguous set of layers.
+            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])

         if mpu.is_pipeline_last_stage():
             # Final layer norm before output.
-            LayerNorm = import_layernorm(args.fp32_residual_connection)
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon)
...
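A small worked example of the offset arithmetic in the hunk above, with the mpu rank getters replaced by plain integers and the sizes fixed at num_layers=8, pipeline size 2 and 2 virtual chunks (matching the second comment table); the loop variables are assumptions for illustration only.

num_layers = 8
pipeline_size = 2
virtual_size = 2
layers_per_chunk = num_layers // pipeline_size // virtual_size   # 2, i.e. self.num_layers

for virtual_rank in range(virtual_size):
    for pipeline_rank in range(pipeline_size):
        offset = virtual_rank * (num_layers // virtual_size) + \
            pipeline_rank * layers_per_chunk
        chunk = [offset + i for i in range(layers_per_chunk)]
        print(virtual_rank, pipeline_rank, chunk)
# virtual chunk 0: stage 0 -> [0, 1], stage 1 -> [2, 3]
# virtual chunk 1: stage 0 -> [4, 5], stage 1 -> [6, 7]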
megatron/mpu/__init__.py View file @ ee7b19e7
...
@@ -38,6 +38,7 @@ from .initialize import get_pipeline_model_parallel_next_rank
 from .initialize import get_pipeline_model_parallel_prev_rank
 from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
 from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
+from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized
...
@@ -58,6 +59,8 @@ from .random import get_cuda_rng_tracker
 from .random import init_checkpointed_activations_memory_buffer
 from .random import model_parallel_cuda_manual_seed
 from .random import reset_checkpointed_activations_memory_buffer
+from .random import gather_split_1d_tensor
+from .random import split_tensor_into_1d_equal_chunks
 from .utils import divide
 from .utils import split_tensor_along_last_dim
megatron/mpu/initialize.py View file @ ee7b19e7
...
@@ -32,6 +32,9 @@ _EMBEDDING_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None

+_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
+_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+
 # These values enable us to change the mpu sizes on the fly.
 _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
 _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
...
@@ -48,7 +51,8 @@ def is_unitialized():
 def initialize_model_parallel(tensor_model_parallel_size_=1,
-                              pipeline_model_parallel_size_=1):
+                              pipeline_model_parallel_size_=1,
+                              virtual_pipeline_model_parallel_size_=None):
     """
     Initialize model data parallel groups.
...
@@ -91,6 +95,12 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
     num_data_parallel_groups = world_size // data_parallel_size

+    if virtual_pipeline_model_parallel_size_ is not None:
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
+        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
+
     rank = torch.distributed.get_rank()

     # Build the data-parallel groups.
...
@@ -258,17 +268,46 @@ def get_pipeline_model_parallel_rank():
     return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


-def is_pipeline_first_stage():
+def is_pipeline_first_stage(ignore_virtual=False):
     """Return True if in the first pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        if get_virtual_pipeline_model_parallel_world_size() is not None and \
+            get_virtual_pipeline_model_parallel_rank() != 0:
+            return False
     return get_pipeline_model_parallel_rank() == 0


-def is_pipeline_last_stage():
+def is_pipeline_last_stage(ignore_virtual=False):
     """Return True if in the last pipeline model-parallel stage, False otherwise."""
+    if not ignore_virtual:
+        virtual_pipeline_model_parallel_world_size = \
+            get_virtual_pipeline_model_parallel_world_size()
+        if virtual_pipeline_model_parallel_world_size is not None and \
+            get_virtual_pipeline_model_parallel_rank() != (
+                virtual_pipeline_model_parallel_world_size - 1):
+            return False
     return get_pipeline_model_parallel_rank() == (
         get_pipeline_model_parallel_world_size() - 1)


+def get_virtual_pipeline_model_parallel_rank():
+    """Return the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+
+
+def set_virtual_pipeline_model_parallel_rank(rank):
+    """Set the virtual pipeline-parallel rank."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
+    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
+
+
+def get_virtual_pipeline_model_parallel_world_size():
+    """Return the virtual pipeline-parallel world size."""
+    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
+
+
 def get_tensor_model_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
     in the tensor model parallel group."""
...
@@ -276,11 +315,13 @@ def get_tensor_model_parallel_src_rank():
     local_world_size = get_tensor_model_parallel_world_size()
     return (global_rank // local_world_size) * local_world_size


 def get_pipeline_model_parallel_first_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
     return _PIPELINE_GLOBAL_RANKS[0]


 def get_pipeline_model_parallel_last_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
...
@@ -294,6 +335,7 @@ def get_pipeline_model_parallel_next_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


 def get_pipeline_model_parallel_prev_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
...
@@ -301,6 +343,7 @@ def get_pipeline_model_parallel_prev_rank():
     world_size = get_pipeline_model_parallel_world_size()
     return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


 def get_data_parallel_world_size():
     """Return world size for the data parallel group."""
     return torch.distributed.get_world_size(group=get_data_parallel_group())
...
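As a sanity check on the new ignore_virtual behaviour above, a self-contained sketch with plain functions standing in for the mpu globals (the function and argument names here are assumptions, not part of the commit): a stage is only reported as "first" or "last" when its virtual chunk is also the first or last one, unless the virtual dimension is explicitly ignored.

def is_first(pipeline_rank, virtual_rank, virtual_size, ignore_virtual=False):
    if not ignore_virtual and virtual_size is not None and virtual_rank != 0:
        return False
    return pipeline_rank == 0

def is_last(pipeline_rank, pipeline_size, virtual_rank, virtual_size,
            ignore_virtual=False):
    if not ignore_virtual and virtual_size is not None and \
            virtual_rank != virtual_size - 1:
        return False
    return pipeline_rank == pipeline_size - 1

# Stage 0 holding model chunk 1 of 2 is not the "true" first stage...
assert not is_first(pipeline_rank=0, virtual_rank=1, virtual_size=2)
# ...unless the virtual dimension is explicitly ignored.
assert is_first(pipeline_rank=0, virtual_rank=1, virtual_size=2, ignore_virtual=True)
assert is_last(pipeline_rank=1, pipeline_size=2, virtual_rank=1, virtual_size=2)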
megatron/optimizer/__init__.py View file @ ee7b19e7
...
@@ -17,33 +17,32 @@ from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD

 from megatron import get_args
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm

 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
-from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
+from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


-def _get_params_for_weight_decay_optimization(module):
+def _get_params_for_weight_decay_optimization(modules):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and baises will have no weight decay but the rest will.
     """
-    args = get_args()
-    LayerNorm = import_layernorm(args.fp32_residual_connection)
-
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module_ in module.modules():
-        if isinstance(module_, LayerNorm):
-            no_weight_decay_params['params'].extend(
-                [p for p in list(module_._parameters.values())
-                 if p is not None])
-        else:
-            weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n != 'bias'])
-            no_weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n == 'bias'])
+    for module in modules:
+        for module_ in module.modules():
+            if isinstance(module_, LayerNorm):
+                no_weight_decay_params['params'].extend(
+                    [p for p in list(module_._parameters.values())
+                     if p is not None])
+            else:
+                weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n != 'bias'])
+                no_weight_decay_params['params'].extend(
+                    [p for n, p in list(module_._parameters.items())
+                     if p is not None and n == 'bias'])

     return weight_decay_params, no_weight_decay_params
...
@@ -66,24 +65,45 @@ def get_megatron_optimizer(model):
                         momentum=args.sgd_momentum)
     else:
         raise Exception('{} optimizer is not supported.'.format(
             args.optimizer))

-    if args.fp16:
+    # Determine whether the params have main-grad field.
+    params_have_main_grad = False
+    if args.DDP_impl == 'local':
+        params_have_main_grad = True
+
+    if args.fp16 or args.bf16:
+
+        # Grad scaler:
+        #    if loss-scale is provided, instantiate the constant scaler.
+        #    if we are using fp16 and loss-scale is not present, use a
+        #       dynamic scaler.
+        #    otherwise we are running in bf16 with no loss-scale so
+        #       leave it as None.
+        grad_scaler = None
         # Constant loss scale.
         if args.loss_scale:
             grad_scaler = ConstantGradScaler(args.loss_scale)
         # Dynamic loss scale.
         else:
-            grad_scaler = DynamicGradScaler(
-                initial_scale=args.initial_loss_scale,
-                min_scale=args.min_loss_scale,
-                growth_factor=2.0,
-                backoff_factor=0.5,
-                growth_interval=args.loss_scale_window,
-                hysteresis=args.hysteresis)
+            if args.fp16:
+                grad_scaler = DynamicGradScaler(
+                    initial_scale=args.initial_loss_scale,
+                    min_scale=args.min_loss_scale,
+                    growth_factor=2.0,
+                    backoff_factor=0.5,
+                    growth_interval=args.loss_scale_window,
+                    hysteresis=args.hysteresis)
+
         # Megatron optimizer.
-        return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
-                                           args.clip_grad)
+        return Float16OptimizerWithFloat16Params(optimizer,
+                                                 args.clip_grad,
+                                                 args.log_num_zeros_in_grad,
+                                                 params_have_main_grad,
+                                                 args.bf16,
+                                                 grad_scaler)

     # FP32.
-    return FP32Optimizer(optimizer, args.clip_grad)
+    return FP32Optimizer(optimizer, args.clip_grad,
+                         args.log_num_zeros_in_grad,
+                         params_have_main_grad)
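The scaler selection in the hunk above reduces to a small decision table; a hedged sketch of just that branch structure, with plain callables standing in for ConstantGradScaler/DynamicGradScaler (the function name and its parameters are illustrative, not part of the commit).

def pick_grad_scaler(fp16, bf16, loss_scale, make_constant, make_dynamic):
    """Return a scaler, or None when running bf16 without a loss scale."""
    if not (fp16 or bf16):
        return None                         # FP32 path: no scaler at all
    if loss_scale:
        return make_constant(loss_scale)    # fixed scale, fp16 or bf16
    if fp16:
        return make_dynamic()               # fp16 without a fixed scale
    return None                             # bf16 without a loss scale

assert pick_grad_scaler(False, True, None, id, id) is None
assert pick_grad_scaler(True, False, 4096,
                        lambda s: ('const', s), lambda: 'dyn') == ('const', 4096)
assert pick_grad_scaler(True, False, None, id, lambda: 'dyn') == 'dyn'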
megatron/optimizer/clip_grads.py View file @ ee7b19e7
...
@@ -118,3 +118,31 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
                              clip_coeff)

     return total_norm
+
+
+def count_zeros_fp32(parameters):
+
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+
+    # Filter parameters based on:
+    #   - grad should not be none
+    #   - parameter should not be shared
+    #   - should not be a replica due to tensor model parallelism
+    total_num_zeros = 0.0
+    for param in parameters:
+        grad_not_none = param.grad is not None
+        is_not_shared = param_is_not_shared(param)
+        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+        if grad_not_none and is_not_shared and is_not_tp_duplicate:
+            grad = param.grad.detach()
+            num_zeros = grad.numel() - torch.count_nonzero(grad)
+            total_num_zeros = num_zeros + total_num_zeros
+
+    # Sum across all model-parallel GPUs.
+    torch.distributed.all_reduce(total_num_zeros,
+                                 op=torch.distributed.ReduceOp.SUM,
+                                 group=mpu.get_model_parallel_group())
+    total_num_zeros = total_num_zeros.item()
+
+    return total_num_zeros
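The core of the new count_zeros_fp32 is just numel() minus count_nonzero(), summed across model-parallel ranks; a single-process sketch of that counting step (the distributed all-reduce is deliberately omitted here).

import torch

grad = torch.tensor([0.0, 1.5, 0.0, -2.0, 0.0])
num_zeros = grad.numel() - torch.count_nonzero(grad)
print(int(num_zeros))   # 3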
megatron/optimizer/optimizer.py View file @ ee7b19e7
...
@@ -27,7 +27,7 @@ from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0

-from .clip_grads import clip_grad_norm_fp32
+from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32


 def _zero_grad_group_helper(group, set_to_none):
...
@@ -46,49 +46,77 @@ def _zero_grad_group_helper(group, set_to_none):

 def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
-    """Use multi-tensor-applier to copy values from one list to another."""
+    """Use multi-tensor-applier to copy values from one list to another.
+    We don't have a blfoat16 implementation so for now if the overflow_buf
+    is not provided, we default back to simple loop copy to be compatible
+    with bfloat16."""
     if overflow_buf:
         overflow_buf.fill_(0)
-    else:
-        overflow_buf = torch.cuda.IntTensor([0])
-    # Scaling with factor `1.0` is equivalent to copy.
-    multi_tensor_applier(amp_C.multi_tensor_scale,
-                         overflow_buf,
-                         [this, that],
-                         1.0)
+        # Scaling with factor `1.0` is equivalent to copy.
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             overflow_buf,
+                             [this, that],
+                             1.0)
+    else:
+        for this_, that_ in zip(this, that):
+            that_.copy_(this_)


 class MegatronOptimizer(ABC):

-    def __init__(self, optimizer):
+    def __init__(self, optimizer, clip_grad,
+                 log_num_zeros_in_grad,
+                 params_have_main_grad):
         """Input optimizer is the base optimizer for example Adam."""
         self.optimizer = optimizer
         assert self.optimizer, 'no optimizer is provided.'
+        # Set gradient clipping and logging params.
+        self.clip_grad = clip_grad
+        self.log_num_zeros_in_grad = log_num_zeros_in_grad
+        self.params_have_main_grad = params_have_main_grad

-    def clip_grad_norm(self, clip_grad):
+    def get_parameters(self):
         params = []
         for param_group in self.optimizer.param_groups:
             for param in param_group['params']:
                 params.append(param)
+        return params
+
+    def clip_grad_norm(self, clip_grad):
+        params = self.get_parameters()
         return clip_grad_norm_fp32(params, clip_grad)

+    def count_zeros(self):
+        params = self.get_parameters()
+        return count_zeros_fp32(params)
+
     @abstractmethod
     def zero_grad(self, set_to_none=True):
         pass

     @abstractmethod
     def get_loss_scale(self):
         """The output should be a cuda tensor of size 1."""
         pass

     def scale_loss(self, loss):
         """Simple scaling."""
         return self.get_loss_scale() * loss

     @abstractmethod
     def step(self):
         pass

     @abstractmethod
     def reload_model_params(self):
         """Refreshes any internal state from the current model parameters.
...
@@ -98,14 +126,17 @@ class MegatronOptimizer(ABC):
         with main parameters, the main parameters need to also be updated."""
         pass

     @abstractmethod
     def state_dict(self):
         pass

     @abstractmethod
     def load_state_dict(self, state_dict):
         pass

     # Promote state so it can be retrieved or set via
     # "optimizer_instance.state"
     def _get_state(self):
...
@@ -116,6 +147,7 @@ class MegatronOptimizer(ABC):
     state = property(_get_state, _set_state)

     # Promote param_groups so it can be retrieved or set via
     # "optimizer_instance.param_groups"
     # (for example, to adjust the learning rate)
...
@@ -129,49 +161,90 @@ class MegatronOptimizer(ABC):

-class FP16OptimizerWithFP16Params(MegatronOptimizer):
+class Float16OptimizerWithFloat16Params(MegatronOptimizer):
+    """Float16 optimizer for fp16 and bf16 data types.
+
+    Arguments:
+        optimizer: base optimizer such as Adam or SGD
+        clip_grad: clip gradeints with this global L2 norm. Note
+            that clipping is ignored if clip_grad == 0
+        log_num_zeros_in_grad: return number of zeros in the gradients.
+        params_have_main_grad: flag indicating if parameters have
+            a `main_grad` field. If this is set, we are assuming
+            that the model parameters are store in the `main_grad`
+            field instead of the typical `grad` field. This happens
+            for the DDP cases where there is a contihuous buffer
+            holding the gradients. For example for bfloat16, we want
+            to do gradient accumulation and all-reduces in float32
+            and as a result we store those gradients in the main_grad.
+            Note that main grad is not necessarily in float32.
+        bf16: if true, the model is running in bfloat16.
+        grad_scaler: used for scaling gradients. Note that this can be
+            None. This case happens when `bf16 = True` and we don't
+            use any loss scale. Note that for `bf16 = True`, we can have
+            a constnat gradient scaler. Also for `bf16 = False`, we
+            always require a grad scaler.
+    """

-    def __init__(self, optimizer, grad_scaler, clip_grad):
-        super(FP16OptimizerWithFP16Params, self).__init__(optimizer)
+    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
+                 params_have_main_grad, bf16, grad_scaler):
+        super(Float16OptimizerWithFloat16Params, self).__init__(
+            optimizer, clip_grad, log_num_zeros_in_grad,
+            params_have_main_grad)

+        self.bf16 = bf16
         self.grad_scaler = grad_scaler
-        self.clip_grad = clip_grad
+        # None grad scaler is only supported for bf16.
+        if self.grad_scaler is None:
+            assert self.bf16, 'fp16 expects a grad scaler.'

         # Tensor used to determine if a nan/if has happend.
         # Any non-zero value indicates inf/nan.
-        self.found_inf = torch.cuda.FloatTensor([0.0])
+        # Note that we keep this for the cases that grad scaler is none.
+        # We still record nan/inf if we have a bfloat16 with a grad scaler.
+        if self.grad_scaler:
+            self.found_inf = torch.cuda.FloatTensor([0.0])

         # Dummy tensor needed for apex multi-apply tensor.
-        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+        # For bfloat, we don't have multi-tensor apply and for now
+        # we set it to none so the multi-tensor apply gets ignored.
+        if bf16:
+            self._dummy_overflow_buf = None
+        else:
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+
+        # In case grad scaler is not passed, define the unity scale.
+        if self.grad_scaler is None:
+            self._scale_one = torch.cuda.FloatTensor([1.0])

         # ======================
         # main parameter stuff
         # ======================

         # Three groups of parameters:
-        #   fp16_groups: original fp16 parameters
-        #   fp32_from_fp16_groups: fp32 copy of fp16 parameters
+        #   float16_groups: original float16 parameters
+        #   fp32_from_float16_groups: fp32 copy of float16 parameters
         #   fp32_from_fp32_groups: original fp32 parameters
-        self.fp16_groups = []
-        self.fp32_from_fp16_groups = []
+        self.float16_groups = []
+        self.fp32_from_float16_groups = []
         self.fp32_from_fp32_groups = []

         # For all the groups in the original optimizer:
         for param_group in self.optimizer.param_groups:
-            fp16_params_this_group = []
+            float16_params_this_group = []
             fp32_params_this_group = []
-            fp32_from_fp16_params_this_group = []
+            fp32_from_float16_params_this_group = []
             # For all the parameters in this group:
             for i, param in enumerate(param_group['params']):
                 if param.requires_grad:

-                    # fp16 params:
-                    if param.type() == 'torch.cuda.HalfTensor':
-                        fp16_params_this_group.append(param)
+                    # float16 params:
+                    if param.type() in ['torch.cuda.HalfTensor',
+                                        'torch.cuda.BFloat16Tensor']:
+                        float16_params_this_group.append(param)
                         # Create a copy
                         main_param = param.detach().clone().float()
-                        # Store grads
-                        main_param.requires_grad = True
                         # Copy tensor model parallel attributes.
                         mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                   param)
...
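A side note on the grouping the constructor in the hunk above performs: parameters are bucketed by dtype, and each half/bfloat16 parameter gets an fp32 master copy that the base optimizer actually steps on. A minimal sketch using bare CPU tensors and dtype checks instead of optimizer param groups and CUDA type strings (the bucket names mirror the attributes above; everything else is illustrative).

import torch

params = [torch.zeros(2, dtype=torch.float16),
          torch.zeros(2, dtype=torch.bfloat16),
          torch.zeros(2, dtype=torch.float32)]

float16_group, fp32_from_float16_group, fp32_group = [], [], []
for p in params:
    if p.dtype in (torch.float16, torch.bfloat16):
        float16_group.append(p)
        # The optimizer will actually step on this fp32 master copy.
        fp32_from_float16_group.append(p.detach().clone().float())
    else:
        fp32_group.append(p)

print(len(float16_group), len(fp32_from_float16_group), len(fp32_group))  # 2 2 1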
@@ -179,7 +252,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                             main_param.shared = param.shared
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = main_param
-                        fp32_from_fp16_params_this_group.append(main_param)
+                        fp32_from_float16_params_this_group.append(main_param)
                         # Reset existing state dict key to the new main param.
                         if param in self.optimizer.state:
                             self.optimizer.state[main_param] \
...
@@ -191,13 +264,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                         param_group['params'][i] = param
                     else:
-                        raise TypeError("Wrapped parameters must be either "
-                                        "torch.cuda.FloatTensor or "
-                                        "torch.cuda.HalfTensor. "
-                                        "Received {}".format(param.type()))
+                        raise TypeError('Wrapped parameters must be one of '
+                                        'torch.cuda.FloatTensor, '
+                                        'torch.cuda.HalfTensor, or '
+                                        'torch.cuda.BFloat16Tensor. '
+                                        'Received {}'.format(param.type()))

-            self.fp16_groups.append(fp16_params_this_group)
-            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
+            self.float16_groups.append(float16_params_this_group)
+            self.fp32_from_float16_groups.append(
+                fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)

         # Leverage state_dict() and load_state_dict() to
...
@@ -207,37 +282,40 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
-        fp16_groups & fp32_from_fp32_groups."""
-        for group in self.fp16_groups:
+        float16_groups & fp32_from_fp32_groups."""
+        for group in self.float16_groups:
             _zero_grad_group_helper(group, set_to_none)
         for group in self.fp32_from_fp32_groups:
             _zero_grad_group_helper(group, set_to_none)

     def get_loss_scale(self):
+        if self.grad_scaler is None:
+            return self._scale_one
         return self.grad_scaler.scale

     def _copy_model_grads_to_main_grads(self):
-        # This only needs to be done for the fp16 group.
-        model_grads = []
-        main_grads = []
-        for model_group, main_group in zip(self.fp16_groups,
-                                           self.fp32_from_fp16_groups):
+        # This only needs to be done for the float16 group.
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
             for model_param, main_param in zip(model_group, main_group):
-                if model_param.grad is not None:
-                    if main_param.grad is None:
-                        main_param.grad = torch.empty_like(main_param)
-                    model_grads.append(model_param.grad.data)
-                    main_grads.append(main_param.grad.data)
-        _multi_tensor_copy_this_to_that(this=model_grads, that=main_grads,
-                                        overflow_buf=self._dummy_overflow_buf)
+                if self.params_have_main_grad:
+                    main_param.grad = model_param.main_grad.float()
+                else:
+                    if model_param.grad is not None:
+                        main_param.grad = model_param.grad.float()
+
+        # For fp32 grads, we need to reset the grads to main grad.
+        if self.params_have_main_grad:
+            for model_group in self.fp32_from_fp32_groups:
+                for model_param in model_group:
+                    model_param.grad = model_param.main_grad

     def _unscale_main_grads_and_check_for_nan(self):
         main_grads = []
-        # fp32 params fromm fp16 ones.
-        for main_group in self.fp32_from_fp16_groups:
+        # fp32 params fromm float16 ones.
+        for main_group in self.fp32_from_float16_groups:
             for main_param in main_group:
                 if main_param.grad is not None:
                     main_grads.append(main_param.grad.data)
...
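The params_have_main_grad branch in the hunk above exists because the local DDP wrapper accumulates gradients into a contiguous main_grad buffer rather than the usual .grad field; a tiny sketch of the two copy paths, with the attributes set by hand purely for illustration.

import torch

model_param = torch.nn.Parameter(torch.randn(3, dtype=torch.float16))
main_param = model_param.detach().clone().float()

# Path 1: DDP_impl == 'local' keeps an accumulation buffer attached to the param.
model_param.main_grad = torch.randn(3, dtype=torch.float32)
main_param.grad = model_param.main_grad.float()

# Path 2: plain .grad path, upcast on copy.
model_param.grad = torch.randn(3, dtype=torch.float16)
main_param.grad = model_param.grad.float()
print(main_param.grad.dtype)   # torch.float32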
@@ -261,11 +339,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):

         return found_inf_flag

-    def _get_model_and_main_params_data_fp16(self):
+    def _get_model_and_main_params_data_float16(self):
         model_data = []
         main_data = []
-        for model_group, main_group in zip(self.fp16_groups,
-                                           self.fp32_from_fp16_groups):
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
             for model_param, main_param in zip(model_group, main_group):
                 model_data.append(model_param.data)
                 main_data.append(main_param.data)
...
@@ -273,15 +351,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):

     def _copy_main_params_to_model_params(self):
-        # Only needed for the fp16 params.
-        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        # Only needed for the float16 params.
+        model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                         overflow_buf=self._dummy_overflow_buf)

     def _copy_model_params_to_main_params(self):
-        # Only needed for the fp16 params.
-        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        # Only needed for the float16 params.
+        model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                         overflow_buf=self._dummy_overflow_buf)
...
@@ -300,18 +378,22 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         self._copy_model_grads_to_main_grads()
         timers('optimizer-copy-to-main-grad').stop()

-        # Unscale and check for inf/nan.
-        timers('optimizer-unscale-and-check-inf').start()
-        found_inf_flag = self._unscale_main_grads_and_check_for_nan()
-        timers('optimizer-unscale-and-check-inf').stop()
+        # Do unscale, check for inf, and update grad scaler only for
+        # the case that grad scaler is provided.
+        if self.grad_scaler:

-        # We are done with scaling gradients
-        # so we can update the loss scale.
-        self.grad_scaler.update(found_inf_flag)
+            # Unscale and check for inf/nan.
+            timers('optimizer-unscale-and-check-inf').start()
+            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
+            timers('optimizer-unscale-and-check-inf').stop()

-        # If we found inf/nan, skip the update.
-        if found_inf_flag:
-            return False, None
+            # We are done with scaling gradients
+            # so we can update the loss scale.
+            self.grad_scaler.update(found_inf_flag)
+
+            # If we found inf/nan, skip the update.
+            if found_inf_flag:
+                return False, None, None

         # Clip the main gradients.
         timers('optimizer-clip-main-grad').start()
...
@@ -320,6 +402,10 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
             grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()

+        # count the zeros in the grads
+        num_zeros_in_grad = self.count_zeros() if \
+                            self.log_num_zeros_in_grad else None
+
         # Step the optimizer.
         self.optimizer.step()
...
@@ -329,14 +415,15 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         timers('optimizer-copy-main-to-model-params').stop()

         # Successful update.
-        return True, grad_norm
+        return True, grad_norm, num_zeros_in_grad

     def state_dict(self):
         state_dict = {}
         state_dict['optimizer'] = self.optimizer.state_dict()
-        state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
+        if self.grad_scaler:
+            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
         return state_dict
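step() now reports three values to the caller instead of two. A sketch of how a caller is expected to unpack them, with a stub class standing in for the optimizer (the stub and the surrounding names are assumptions; only the (success, grad_norm, num_zeros) contract comes from the return statements above).

class _StubOptimizer:
    """Stand-in with the same step() contract as the classes above."""
    def step(self):
        return True, 1.7, None   # (update_successful, grad_norm, num_zeros_in_grad)

optimizer = _StubOptimizer()
skipped_iterations = 0

update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
if update_successful:
    pass                          # a real run would advance the LR scheduler here
else:
    skipped_iterations += 1       # fp16 overflow: the loss scale was backed off
print(update_successful, grad_norm, num_zeros_in_grad)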
...
@@ -354,15 +441,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
             print_rank_0('***WARNING*** found an old checkpoint, will not '
                          'load grad scaler ...')
         else:
-            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            if self.grad_scaler:
+                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+            else:
+                print_rank_0('***WARNING*** fould the grad scaler in the '
+                             'checkpoint but it is None in the class. '
+                             'Skipping loading grad scaler ...')

         # Copy data for the main params.
-        fp32_from_fp16_params_key = 'fp32_from_fp16_params'
-        if fp32_from_fp16_params_key not in state_dict:
-            fp32_from_fp16_params_key = 'fp32_from_fp16'
+        fp32_from_float16_params_key = 'fp32_from_fp16_params'
+        if fp32_from_float16_params_key not in state_dict:
+            fp32_from_float16_params_key = 'fp32_from_fp16'
         for current_group, saved_group in zip(
-                self.fp32_from_fp16_groups,
-                state_dict[fp32_from_fp16_params_key]):
+                self.fp32_from_float16_groups,
+                state_dict[fp32_from_float16_params_key]):
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
...
@@ -370,10 +462,14 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):

 class FP32Optimizer(MegatronOptimizer):

-    def __init__(self, optimizer, clip_grad):
-        super(FP32Optimizer, self).__init__(optimizer)
-        self.clip_grad = clip_grad
+    def __init__(self, optimizer, clip_grad,
+                 log_num_zeros_in_grad,
+                 params_have_main_grad):
+        super(FP32Optimizer, self).__init__(
+            optimizer, clip_grad,
+            log_num_zeros_in_grad,
+            params_have_main_grad)
+
         self._scale = torch.cuda.FloatTensor([1.0])
...
@@ -393,16 +489,26 @@ class FP32Optimizer(MegatronOptimizer):
         """Clip gradients (if needed) and step the base optimizer.
         Always return successful since there is no overflow."""

+        # Copy main_grads to grads.
+        if self.params_have_main_grad:
+            for param_group in self.optimizer.param_groups:
+                for param in param_group['params']:
+                    param.grad = param.main_grad
+
         # Clip gradients.
         grad_norm = None
         if self.clip_grad > 0.0:
             grad_norm = self.clip_grad_norm(self.clip_grad)

+        # count the zeros in the grads
+        num_zeros_in_grad = self.count_zeros() if \
+                            self.log_num_zeros_in_grad else None
+
         # Update parameters.
         self.optimizer.step()

         # No overflow for FP32 optimizer.
-        return True, grad_norm
+        return True, grad_norm, num_zeros_in_grad

     def reload_model_params(self):
...
megatron/p2p_communication.py 0 → 100644 View file @ ee7b19e7
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
import operator
import torch

from megatron import get_args
from megatron import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 use_ring_exchange=False):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.

    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            mpu.get_tensor_model_parallel_world_size()
    else:
        tensor_chunk_shape = tensor_shape
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if use_ring_exchange:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()

    return tensor_recv_prev, tensor_recv_next
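Before the wrapper helpers that follow, a quick illustration of the scatter-gather path above: each rank sends only its 1/tensor-parallel-size slice of the flattened activation and the receiver reassembles the full shape. This single-process sketch replaces the mpu helpers with plain tensor ops, so no communication actually happens; sizes are illustrative.

import torch

seq_length, micro_batch_size, hidden_size = 4, 2, 8
tp_world_size = 2

full = torch.randn(seq_length, micro_batch_size, hidden_size)

# split_tensor_into_1d_equal_chunks: flatten and keep one contiguous slice per TP rank.
flat = full.reshape(-1)
chunk_numel = flat.numel() // tp_world_size
chunks = [flat[r * chunk_numel:(r + 1) * chunk_numel] for r in range(tp_world_size)]

# gather_split_1d_tensor: concatenate the slices back and restore the 3D view.
rebuilt = torch.cat(chunks).view(seq_length, micro_batch_size, hidden_size)
assert torch.equal(rebuilt, full)
print(chunk_numel)   # 32, i.e. 4*2*8 // 2, the tensor_chunk_shape used above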
def recv_forward(timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor


def recv_backward(timers=None):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
        _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('forward-send').stop()


def send_backward(input_tensor_grad, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send').start()
        _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, timers=None):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('forward-send-backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
megatron/schedules.py 0 → 100644 View file @ ee7b19e7
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication


def forward_step(forward_step_func, data_iterator, model, input_tensor,
                 losses_reduced):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.

    Returns output tensor."""
    timers = get_timers()

    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    timers('forward-compute').stop()

    return output_tensor


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.

    Returns gradient of loss with respect to input tensor (None if first
    stage)."""
    args = get_args()

    timers = get_timers()
    timers('backward-compute').start()

    # Retain the grad on the input_tensor.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass.
    if output_tensor_grad is None:
        output_tensor = optimizer.scale_loss(output_tensor)
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad

    timers('backward-compute').stop()

    return input_tensor_grad


@contextmanager
def dummy_handler():
    try:
        yield
    finally:
        pass


def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers, forward_only):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).

    Returns dictionary with losses."""
    assert len(model) == 1
    model = model[0]

    context_handler = dummy_handler
    if isinstance(model, torchDDP):
        context_handler = model.no_sync

    losses_reduced = []
    input_tensor, output_tensor_grad = None, None
    with context_handler():
        for i in range(get_num_microbatches() - 1):
            output_tensor = forward_step(forward_step_func, data_iterator, model,
                                         input_tensor, losses_reduced)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor = forward_step(forward_step_func, data_iterator, model,
                                 input_tensor, losses_reduced)
    if not forward_only:
        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

    return losses_reduced


def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
                                                  model, optimizer, timers,
                                                  forward_only):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    losses_reduced = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]

    pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
    pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()

    # Compute number of warmup and remaining microbatches.
    num_model_chunks = len(model)
    num_microbatches = get_num_microbatches() * num_model_chunks
    all_warmup_microbatches = False
    if forward_only:
        num_warmup_microbatches = num_microbatches
    else:
        # Run all forward passes and then all backward passes if number of
        # microbatches is just the number of pipeline stages.
        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
        # all workers, followed by more microbatches after depending on
        # stage ID (more forward passes for earlier stages, later stages can
        # immediately start with 1F1B).
        if get_num_microbatches() == pipeline_parallel_size:
            num_warmup_microbatches = num_microbatches
            all_warmup_microbatches = True
        else:
            num_warmup_microbatches = \
                (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
            num_warmup_microbatches += (
                num_model_chunks - 1) * pipeline_parallel_size
            num_warmup_microbatches = min(num_warmup_microbatches,
                                          num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches
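Before the helper definitions that follow, a worked example of the warmup arithmetic just computed, assuming 2 pipeline stages, 2 model chunks per stage and 4 microbatches from get_num_microbatches() (all values illustrative): earlier stages need more warmup forward passes, later stages start 1F1B sooner.

pipeline_parallel_size = 2
num_model_chunks = 2
num_microbatches_global = 4                                 # get_num_microbatches()
num_microbatches = num_microbatches_global * num_model_chunks   # 8

for pipeline_parallel_rank in range(pipeline_parallel_size):
    warmup = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
    warmup += (num_model_chunks - 1) * pipeline_parallel_size
    warmup = min(warmup, num_microbatches)
    print(pipeline_parallel_rank, warmup, num_microbatches - warmup)
# rank 0: 4 warmup forwards, 4 steady-state 1F1B iterations
# rank 1: 2 warmup forwards, 6 steady-state 1F1B iterations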
    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
        model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
        if not forward:
            model_chunk_id = (num_model_chunks - model_chunk_id - 1)
        return model_chunk_id
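A worked example of get_model_chunk_id for the same 2-stage, 2-chunk setup: microbatch ids cycle through the model chunks in groups of pipeline_parallel_size, and the backward direction walks the chunks in reverse (standalone copy of the helper with the sizes fixed, for illustration only).

pipeline_parallel_size = 2
num_model_chunks = 2

def get_model_chunk_id(microbatch_id, forward):
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

print([get_model_chunk_id(k, forward=True) for k in range(8)])
# [0, 0, 1, 1, 0, 0, 1, 1]
print([get_model_chunk_id(k, forward=False) for k in range(8)])
# [1, 1, 0, 0, 1, 1, 0, 0]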
    def forward_step_helper(microbatch_id):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == \
                    len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]
        output_tensor = forward_step(forward_step_func,
                                     data_iterator[model_chunk_id],
                                     model[model_chunk_id],
                                     input_tensor, losses_reduced)
        output_tensors[model_chunk_id].append(output_tensor)

        return output_tensor

    def backward_step_helper(microbatch_id):
        """Helper method to run backward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        backward_step())."""
        model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

        if mpu.is_pipeline_last_stage():
            if len(output_tensor_grads[model_chunk_id]) == 0:
                output_tensor_grads[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id].pop(0)
        output_tensor = output_tensors[model_chunk_id].pop(0)
        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
        input_tensor_grad = \
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

        return input_tensor_grad

    # Run warmup forward passes.
    mpu.set_virtual_pipeline_model_parallel_rank(0)
    input_tensors[0].append(p2p_communication.recv_forward(timers))
    for k in range(num_warmup_microbatches):
        output_tensor = forward_step_helper(k)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            if next_forward_model_chunk_id == 0:
                recv_prev = False
        if k == (num_microbatches - 1):
            recv_prev = False

        # Don't send tensor downstream if on last stage.
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        # Send and receive tensors as appropriate (send tensors computed
        # in this iteration; receive tensors for next iteration).
        if k == (num_warmup_microbatches - 1) and not forward_only and \
                not all_warmup_microbatches:
            input_tensor_grad = None
            recv_next = True
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False
            input_tensor, output_tensor_grad = \
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor, input_tensor_grad,
                    recv_prev=recv_prev, recv_next=recv_next,
                    timers=timers)
            output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
        else:
            input_tensor = \
                p2p_communication.send_forward_recv_forward(
                    output_tensor, recv_prev, timers)
        input_tensors[next_forward_model_chunk_id].append(input_tensor)

    # Run 1F1B in steady state.
    for k in range(num_microbatches_remaining):
        # Forward pass.
        forward_k = k + num_warmup_microbatches
        output_tensor = forward_step_helper(forward_k)

        # Backward pass.
        backward_k = k
        input_tensor_grad = backward_step_helper(backward_k)

        # Send output_tensor and input_tensor_grad, receive input_tensor
        # and output_tensor_grad.

        # Determine if current stage has anything to send in either direction,
        # otherwise set tensor to None.
        forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
        mpu.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
        if mpu.is_pipeline_last_stage():
            output_tensor = None

        backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
        mpu.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
        if mpu.is_pipeline_first_stage():
            input_tensor_grad = None

        # Determine if peers are sending, and where in data structure to put
        # received tensors.
        recv_prev = True
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            # First stage is ahead of last stage by (pipeline_parallel_size - 1).
            next_forward_model_chunk_id = get_model_chunk_id(
                forward_k - (pipeline_parallel_size - 1), forward=True)
            if next_forward_model_chunk_id == (num_model_chunks - 1):
                recv_prev = False
next_forward_model_chunk_id
+=
1
else
:
next_forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
+
1
,
forward
=
True
)
recv_next
=
True
if
mpu
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
-
(
pipeline_parallel_size
-
1
),
forward
=
False
)
if
next_backward_model_chunk_id
==
0
:
recv_next
=
False
next_backward_model_chunk_id
-=
1
else
:
next_backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
+
1
,
forward
=
False
)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if
k
==
(
num_microbatches_remaining
-
1
):
recv_prev
=
False
# Communicate tensors.
input_tensor
,
output_tensor_grad
=
\
p2p_communication
.
send_forward_backward_recv_forward_backward
(
output_tensor
,
input_tensor_grad
,
recv_prev
=
recv_prev
,
recv_next
=
recv_next
,
timers
=
timers
)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if
recv_prev
:
input_tensors
[
next_forward_model_chunk_id
].
append
(
input_tensor
)
if
recv_next
:
output_tensor_grads
[
next_backward_model_chunk_id
].
append
(
output_tensor_grad
)
# Run cooldown backward passes (flush out pipeline).
if
not
forward_only
:
if
all_warmup_microbatches
:
output_tensor_grads
[
num_model_chunks
-
1
].
append
(
p2p_communication
.
recv_backward
(
timers
))
for
k
in
range
(
num_microbatches_remaining
,
num_microbatches
):
input_tensor_grad
=
backward_step_helper
(
k
)
next_backward_model_chunk_id
=
get_model_chunk_id
(
k
+
1
,
forward
=
False
)
recv_next
=
True
if
mpu
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
if
next_backward_model_chunk_id
==
(
num_model_chunks
-
1
):
recv_next
=
False
if
k
==
(
num_microbatches
-
1
):
recv_next
=
False
output_tensor_grads
[
next_backward_model_chunk_id
].
append
(
p2p_communication
.
send_backward_recv_backward
(
input_tensor_grad
,
recv_next
,
timers
))
return
losses_reduced
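A small standalone sketch (the numbers are illustrative, not from this commit) of the bookkeeping above: how microbatch indices map to model chunks and how many warmup microbatches each pipeline rank runs under the interleaved schedule.

# Illustrative numbers: 4 pipeline stages, 2 model chunks per stage,
# 8 microbatches per pipeline (so 16 loop iterations in the schedule above).
pipeline_parallel_size = 4
num_model_chunks = 2
num_microbatches = 8 * num_model_chunks

def get_model_chunk_id(microbatch_id, forward=True):
    microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
    model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
    if not forward:
        model_chunk_id = num_model_chunks - model_chunk_id - 1
    return model_chunk_id

# Microbatches 0-3 run on chunk 0, 4-7 on chunk 1, then the pattern repeats.
print([get_model_chunk_id(k) for k in range(num_microbatches)])

# Warmup microbatches per rank (earlier stages warm up more).
for rank in range(pipeline_parallel_size):
    warmup = (pipeline_parallel_size - rank - 1) * 2
    warmup += (num_model_chunks - 1) * pipeline_parallel_size
    print(f"rank {rank}: {min(warmup, num_microbatches)} warmup microbatches")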
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, timers,
                                                     forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns dictionary with losses if the last stage, empty dict otherwise."""
    timers = get_timers()

    assert len(model) == 1
    model = model[0]

    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (mpu.get_pipeline_model_parallel_world_size() -
         mpu.get_pipeline_model_parallel_rank() - 1)
    num_warmup_microbatches = min(num_warmup_microbatches,
                                  num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        input_tensor = p2p_communication.recv_forward(timers)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        input_tensor = p2p_communication.recv_forward(timers)

    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor, timers)
        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor,
                                                             timers)

        # Add input_tensor and output_tensor to end of list, then pop from the
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers)
        else:
            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad, timers)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(
                        input_tensor_grad, timers)

    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(timers)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad, timers)

    return losses_reduced
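For the non-interleaved schedule above, each rank runs (pipeline_parallel_size - rank - 1) warmup microbatches, the rest in 1F1B steady state, and the same number again in cooldown. A quick illustrative calculation (example numbers only, not tied to this commit):

# Illustrative only: 4 pipeline stages, 8 microbatches.
pipeline_parallel_size = 4
num_microbatches = 8

for rank in range(pipeline_parallel_size):
    num_warmup = min(pipeline_parallel_size - rank - 1, num_microbatches)
    remaining = num_microbatches - num_warmup
    print(f"rank {rank}: {num_warmup} warmup, {remaining} 1F1B, {num_warmup} cooldown")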
megatron/training.py  View file @ ee7b19e7
@@ -37,20 +37,23 @@ from megatron import print_rank_0
 from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
-from megatron.model import FP16Module
+from megatron.model import Float16Module
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
+from megatron.schedules import forward_backward_no_pipelining
+from megatron.schedules import forward_backward_pipelining_without_interleaving
+from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory


 def print_datetime(string):
     """Note that this call will sync across all ranks."""
     torch.distributed.barrier()

@@ -107,23 +110,32 @@ def pretrain(train_valid_test_dataset_provider,
     timers = get_timers()

     # Model, optimizer, and learning rate.
-    timers('model and optimizer').start()
+    timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
-    timers('model and optimizer').stop()
+    timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')

     # Data stuff.
-    timers('train/valid/test data iterators').start()
-    train_data_iterator, valid_data_iterator, test_data_iterator \
-        = build_train_valid_test_data_iterators(
-            train_valid_test_dataset_provider)
-    timers('train/valid/test data iterators').stop()
+    timers('train/valid/test-data-iterators-setup').start()
+    if args.virtual_pipeline_model_parallel_size is not None:
+        all_data_iterators = [
+            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
+            for _ in range(len(model))
+        ]
+        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
+        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
+        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
+    else:
+        train_data_iterator, valid_data_iterator, test_data_iterator \
+            = build_train_valid_test_data_iterators(
+                train_valid_test_dataset_provider)
+    timers('train/valid/test-data-iterators-setup').stop()
     print_datetime('after dataloaders are built')

     # Print setup timing.
-    print_rank_0('done with setups ...')
-    timers.log(['model and optimizer', 'train/valid/test data iterators'])
+    print_rank_0('done with setup ...')
+    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
     print_rank_0('training ...')

     iteration = 0

@@ -185,13 +197,16 @@ def get_model(model_provider_func):
     # Build model on cpu.
     model = model_provider_func()
+    if not isinstance(model, list):
+        model = [model]

     # Set tensor model parallel attributes if not set.
     # Only parameters that are already tensor model parallel have these
     # attributes set for them. We should make sure the default attributes
     # are set for all params so the optimizer can use them.
-    for param in model.parameters():
-        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+    for model_module in model:
+        for param in model_module.parameters():
+            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:

@@ -199,22 +214,29 @@ def get_model(model_provider_func):
             'model parallel rank ({}, {}): {}'.format(
                 mpu.get_tensor_model_parallel_rank(),
                 mpu.get_pipeline_model_parallel_rank(),
-                sum([p.nelement() for p in model.parameters()])), flush=True)
+                sum([sum([p.nelement() for p in model_module.parameters()])
+                     for model_module in model])), flush=True)

     # GPU allocation.
-    model.cuda(torch.cuda.current_device())
+    for model_module in model:
+        model_module.cuda(torch.cuda.current_device())

     # Fp16 conversion.
-    if args.fp16:
-        model = FP16Module(model)
+    if args.fp16 or args.bf16:
+        model = [Float16Module(model_module, args) for model_module in model]

     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
-        model = torchDDP(model, device_ids=[i], output_device=i,
-                         process_group=mpu.get_data_parallel_group())
+        model = [torchDDP(model_module, device_ids=[i], output_device=i,
+                          process_group=mpu.get_data_parallel_group())
+                 for model_module in model]
         return model

     if args.DDP_impl == 'local':
-        model = LocalDDP(model)
+        model = [LocalDDP(model_module,
+                          args.accumulate_allreduce_grads_in_fp32,
+                          args.use_contiguous_buffers_in_ddp)
+                 for model_module in model]
         return model

     raise NotImplementedError('Unknown DDP implementation specified: {}. '

@@ -270,9 +292,8 @@ def setup_model_and_optimizer(model_provider_func):
     model = get_model(model_provider_func)

-    unwrapped_model = model
-    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-        unwrapped_model = unwrapped_model.module
+    unwrapped_model = unwrap_model(model,
+                                   (torchDDP, LocalDDP, Float16Module))

     optimizer = get_megatron_optimizer(unwrapped_model)
     lr_scheduler = get_learning_rate_scheduler(optimizer)

@@ -282,305 +303,29 @@ def setup_model_and_optimizer(model_provider_func):
         # Extra barrier is added to make sure all ranks report the
         # max time.
         torch.distributed.barrier()
-        timers('load checkpoint').start()
+        timers('load-checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
         torch.distributed.barrier()
-        timers('load checkpoint').stop()
-        timers.log(['load checkpoint'])
+        timers('load-checkpoint').stop()
+        timers.log(['load-checkpoint'])
     else:
         args.iteration = 0

     # We only support local DDP with multiple micro-batches.
-    if get_num_microbatches() > 1:
+    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
         assert args.DDP_impl == 'local'

     # get model without FP16 and/or TorchDDP wrappers
-    unwrapped_model = model
-    while hasattr(unwrapped_model, 'module'):
-        unwrapped_model = unwrapped_model.module
-
-    if args.iteration == 0 and hasattr(unwrapped_model,
-                                       'init_state_dict_from_bert'):
-        print("Initializing ICT from pretrained BERT model", flush=True)
-        unwrapped_model.init_state_dict_from_bert()
+    if args.iteration == 0 and len(unwrapped_model) == 1 \
+        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
+        print_rank_0("Initializing ICT from pretrained BERT model")
+        unwrapped_model[0].init_state_dict_from_bert()
+        if args.fp16:
+            optimizer.reload_model_params()

     return model, optimizer, lr_scheduler


-def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
-    """Communicate tensors between stages."""
-    args = get_args()
-
-    # Create placeholder tensors for receive in forward and backward directions
-    # if needed.
-    tensor_recv_prev = None
-    tensor_recv_next = None
-    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-    dtype = args.params_dtype
-    if args.fp32_residual_connection:
-        dtype = torch.float
-    if recv_forward:
-        tensor_recv_prev = torch.empty(tensor_shape,
-                                       requires_grad=True,
-                                       device=torch.cuda.current_device(),
-                                       dtype=dtype)
-    if recv_backward:
-        tensor_recv_next = torch.empty(tensor_shape,
-                                       requires_grad=True,
-                                       device=torch.cuda.current_device(),
-                                       dtype=dtype)
-
-    # Send tensors in both the forward and backward directions as appropriate.
-    ops = []
-    if tensor_send_prev is not None:
-        send_prev_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_prev,
-            mpu.get_pipeline_model_parallel_prev_rank())
-        ops.append(send_prev_op)
-    if tensor_recv_prev is not None:
-        recv_prev_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_prev,
-            mpu.get_pipeline_model_parallel_prev_rank())
-        ops.append(recv_prev_op)
-    if tensor_send_next is not None:
-        send_next_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_next,
-            mpu.get_pipeline_model_parallel_next_rank())
-        ops.append(send_next_op)
-    if tensor_recv_next is not None:
-        recv_next_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_next,
-            mpu.get_pipeline_model_parallel_next_rank())
-        ops.append(recv_next_op)
-    reqs = torch.distributed.batch_isend_irecv(ops)
-    for req in reqs:
-        req.wait()
-    # Temporary workaround for batch_isend_irecv() race condition.
-    torch.cuda.synchronize()
-
-    return tensor_recv_prev, tensor_recv_next
-
-
-def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
-    """Backward step."""
-    args = get_args()
-    timers = get_timers()
-
-    # Retain the grad on the input_tensor.
-    if input_tensor is not None:
-        input_tensor.retain_grad()
-
-    # Backward pass.
-    if output_tensor_grad is None:
-        output_tensor = optimizer.scale_loss(output_tensor)
-    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
-
-    # Collect the grad of the input_tensor.
-    input_tensor_grad = None
-    if input_tensor is not None:
-        input_tensor_grad = input_tensor.grad
-
-    return input_tensor_grad
-
-
-def forward_step_with_communication(forward_step_func, data_iterator, model,
-                                    input_tensors, output_tensors,
-                                    losses_reduced, timers):
-    args = get_args()
-
-    if not mpu.is_pipeline_first_stage():
-        timers('forward-recv').start()
-        input_tensor, _ = communicate(tensor_send_next=None,
-                                      tensor_send_prev=None,
-                                      recv_forward=True,
-                                      recv_backward=False)
-        timers('forward-recv').stop()
-    else:
-        input_tensor = None
-
-    # Forward model for one step.
-    timers('forward-compute').start()
-    output_tensor = forward_step_func(data_iterator, model, input_tensor)
-    timers('forward-compute').stop()
-
-    if mpu.is_pipeline_last_stage():
-        loss, loss_reduced = output_tensor
-        output_tensor = loss / get_num_microbatches()
-        losses_reduced.append(loss_reduced)
-    else:
-        timers('forward-send').start()
-        communicate(tensor_send_next=output_tensor,
-                    tensor_send_prev=None,
-                    recv_forward=False,
-                    recv_backward=False)
-        timers('forward-send').stop()
-
-    input_tensors.append(input_tensor)
-    output_tensors.append(output_tensor)
-
-
-def backward_step_with_communication(optimizer, model, input_tensors,
-                                     output_tensors, timers):
-    input_tensor = input_tensors.pop(0)
-    output_tensor = output_tensors.pop(0)
-
-    if mpu.is_pipeline_last_stage():
-        output_tensor_grad = None
-    else:
-        timers('backward-recv').start()
-        _, output_tensor_grad = communicate(tensor_send_next=None,
-                                            tensor_send_prev=None,
-                                            recv_forward=False,
-                                            recv_backward=True)
-        timers('backward-recv').stop()
-
-    # Backward pass for one step.
-    timers('backward-compute').start()
-    input_grad_tensor = \
-        backward_step(optimizer, model, input_tensor, output_tensor,
-                      output_tensor_grad)
-    timers('backward-compute').stop()
-
-    if not mpu.is_pipeline_first_stage():
-        timers('backward-send').start()
-        communicate(tensor_send_next=None,
-                    tensor_send_prev=input_grad_tensor,
-                    recv_forward=False,
-                    recv_backward=False)
-        timers('backward-send').stop()
-
-
-def forward_and_backward_steps_with_communication(forward_step_func, data_iterator,
-                                                  model, optimizer,
-                                                  input_tensor, last_microbatch,
-                                                  input_tensors, output_tensors,
-                                                  losses_reduced, timers):
-    args = get_args()
-
-    # Forward model for one step.
-    timers('forward-compute').start()
-    output_tensor = forward_step_func(data_iterator, model, input_tensor)
-    timers('forward-compute').stop()
-
-    if mpu.is_pipeline_last_stage():
-        loss, loss_reduced = output_tensor
-        output_tensor = loss / get_num_microbatches()
-        output_tensor_grad = None
-        losses_reduced.append(loss_reduced)
-    else:
-        timers('forward-send-backward-recv').start()
-        _, output_tensor_grad = communicate(tensor_send_next=output_tensor,
-                                            tensor_send_prev=None,
-                                            recv_forward=False,
-                                            recv_backward=True)
-        timers('forward-send-backward-recv').stop()
-
-    input_tensors.append(input_tensor)
-    output_tensors.append(output_tensor)
-
-    input_tensor = input_tensors.pop(0)
-    output_tensor = output_tensors.pop(0)
-
-    # Backward pass for one step.
-    timers('backward-compute').start()
-    input_grad_tensor = \
-        backward_step(optimizer, model, input_tensor, output_tensor,
-                      output_tensor_grad)
-    timers('backward-compute').stop()
-
-    if not mpu.is_pipeline_first_stage():
-        timers('backward-send-forward-recv').start()
-        input_tensor, _ = communicate(tensor_send_next=None,
-                                      tensor_send_prev=input_grad_tensor,
-                                      recv_forward=(not last_microbatch),
-                                      recv_backward=False)
-        timers('backward-send-forward-recv').stop()
-    else:
-        input_tensor = None
-
-    return input_tensor
-
-
-def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
-                                   optimizer, timers):
-    """Run forward and backward passes without inter-stage communication."""
-    args = get_args()
-
-    losses_reduced = []
-    for i in range(get_num_microbatches()):
-        timers('forward-compute').start()
-        loss, loss_reduced = forward_step_func(data_iterator, model,
-                                               input_tensor=None)
-        output_tensor = loss / get_num_microbatches()
-        losses_reduced.append(loss_reduced)
-        timers('forward-compute').stop()
-
-        timers('backward-compute').start()
-        output_tensor_grad = None
-        backward_step(optimizer, model, input_tensor=None,
-                      output_tensor=output_tensor, output_tensor_grad=None)
-        timers('backward-compute').stop()
-
-    return losses_reduced
-
-
-def forward_backward_pipelining(forward_step_func, data_iterator, model,
-                                optimizer, timers):
-    """Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
-    args = get_args()
-
-    # Compute number of warmup microbatches.
-    num_microbatches = get_num_microbatches()
-    num_warmup_microbatches = \
-        (mpu.get_pipeline_model_parallel_world_size() -
-         mpu.get_pipeline_model_parallel_rank() - 1)
-    num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
-    num_microbatches_remaining = \
-        num_microbatches - num_warmup_microbatches
-
-    input_tensors = []
-    output_tensors = []
-    losses_reduced = []
-
-    # Run warmup forward passes.
-    for i in range(num_warmup_microbatches):
-        forward_step_with_communication(
-            forward_step_func, data_iterator, model,
-            input_tensors, output_tensors, losses_reduced, timers)
-
-    # Before running 1F1B, need to receive first forward tensor.
-    # If all microbatches are run in warmup / cooldown phase, then no need to
-    # receive this tensor here.
-    if num_microbatches_remaining > 0:
-        if mpu.is_pipeline_first_stage():
-            input_tensor = None
-        else:
-            timers('forward-recv').start()
-            input_tensor, _ = communicate(tensor_send_next=None,
-                                          tensor_send_prev=None,
-                                          recv_forward=True,
-                                          recv_backward=False)
-            timers('forward-recv').stop()
-
-    # Run 1F1B.
-    for i in range(num_microbatches_remaining):
-        last_iteration = (i == (num_microbatches_remaining - 1))
-        input_tensor = \
-            forward_and_backward_steps_with_communication(
-                forward_step_func, data_iterator, model, optimizer,
-                input_tensor, last_iteration,
-                input_tensors, output_tensors, losses_reduced, timers)
-
-    # Run cooldown backward passes.
-    for i in range(num_warmup_microbatches):
-        backward_step_with_communication(
-            optimizer, model, input_tensors, output_tensors, timers)
-
-    return losses_reduced
-
-
 def train_step(forward_step_func, data_iterator,
                model, optimizer, lr_scheduler):
     """Single training step."""

@@ -588,20 +333,31 @@ def train_step(forward_step_func, data_iterator,
     timers = get_timers()

     # Set grad to zero.
-    optimizer.zero_grad()
+    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
+        for partition in model:
+            partition.zero_grad_buffer()
+    else:
+        optimizer.zero_grad()

     if mpu.get_pipeline_model_parallel_world_size() > 1:
-        losses_reduced = forward_backward_pipelining(
-            forward_step_func, data_iterator, model, optimizer, timers)
+        if args.virtual_pipeline_model_parallel_size is not None:
+            forward_backward_func = forward_backward_pipelining_with_interleaving
+            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
+                'number of microbatches is not divisible by pipeline-parallel ' \
+                'size when using interleaved schedule'
+        else:
+            forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
-        losses_reduced = forward_backward_no_pipelining(
-            forward_step_func, data_iterator, model, optimizer, timers)
+        forward_backward_func = forward_backward_no_pipelining
+    losses_reduced = forward_backward_func(
+        forward_step_func, data_iterator, model,
+        optimizer, timers, forward_only=False)

     # All-reduce if needed.
     if args.DDP_impl == 'local':
         timers('backward-params-all-reduce').start()
-        model.allreduce_params(reduce_after=False,
-                               fp32_allreduce=args.fp32_allreduce)
+        for model_module in model:
+            model_module.allreduce_gradients()
         timers('backward-params-all-reduce').stop()

     # All-reduce word_embeddings' grad across first and last stages to ensure

@@ -609,25 +365,32 @@ def train_step(forward_step_func, data_iterator,
     # This should only run for models that support pipelined model parallelism
     # (BERT and GPT-2).
     timers('backward-embedding-all-reduce').start()
-    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
+    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
+            mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
             mpu.get_pipeline_model_parallel_world_size() > 1:
-        unwrapped_model = model
-        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-            unwrapped_model = unwrapped_model.module
+        if mpu.is_pipeline_first_stage(ignore_virtual=True):
+            unwrapped_model = model[0]
+        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
+            unwrapped_model = model[-1]
+        unwrapped_model = unwrap_model(
+            unwrapped_model, (torchDDP, LocalDDP, Float16Module))

         if unwrapped_model.share_word_embeddings:
             word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-            torch.distributed.all_reduce(word_embeddings_weight.grad,
-                                         group=mpu.get_embedding_group())
+            if args.DDP_impl == 'local':
+                grad = word_embeddings_weight.main_grad
+            else:
+                grad = word_embeddings_weight.grad
+            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
     timers('backward-embedding-all-reduce').stop()

     # Update parameters.
     timers('optimizer').start()
-    update_successfull, grad_norm = optimizer.step()
+    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
     timers('optimizer').stop()

     # Update learning rate.
-    if update_successfull:
+    if update_successful:
         increment = get_num_microbatches() * \
                     args.micro_batch_size * \
                     args.data_parallel_size

@@ -636,19 +399,19 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1

-    if mpu.is_pipeline_last_stage():
+    if mpu.is_pipeline_last_stage(ignore_virtual=True):
         # Average loss across microbatches.
         loss_reduced = {}
         for key in losses_reduced[0]:
             losses_reduced_for_key = [x[key] for x in losses_reduced]
             loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
-        return loss_reduced, skipped_iter, grad_norm
-    return {}, skipped_iter, grad_norm
+        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
+    return {}, skipped_iter, grad_norm, num_zeros_in_grad


 def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                  loss_scale, report_memory_flag, skipped_iter,
-                 grad_norm, params_norm):
+                 grad_norm, params_norm, num_zeros_in_grad):
     """Log training information such as losses, timing, ...."""
     args = get_args()
     timers = get_timers()

@@ -692,11 +455,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     add_to_logging('forward-compute')
     add_to_logging('forward-recv')
     add_to_logging('forward-send')
-    add_to_logging('forward-send-backward-recv')
+    add_to_logging('forward-backward-send-forward-backward-recv')
     add_to_logging('backward-compute')
     add_to_logging('backward-recv')
     add_to_logging('backward-send')
     add_to_logging('backward-send-forward-recv')
+    add_to_logging('backward-send-backward-recv')
     add_to_logging('backward-params-all-reduce')
     add_to_logging('backward-embedding-all-reduce')
     add_to_logging('optimizer-copy-to-main-grad')

@@ -736,6 +500,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             writer.add_scalar('grad-norm', grad_norm, iteration)
             writer.add_scalar('grad-norm vs samples', grad_norm,
                               args.consumed_train_samples)
+        if num_zeros_in_grad is not None:
+            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
+            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
+                              args.consumed_train_samples)
         if params_norm is not None:
             writer.add_scalar('params-norm', params_norm, iteration)
             writer.add_scalar('params-norm vs samples', params_norm,

@@ -745,7 +513,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                           normalizer=total_iterations)
     if iteration % args.log_interval == 0:
-        elapsed_time = timers('interval time').elapsed()
+        elapsed_time = timers('interval-time').elapsed()
         elapsed_time_per_iteration = elapsed_time / total_iterations
         if writer and torch.distributed.get_rank() == 0:
             if args.log_timers_to_tensorboard:

@@ -770,6 +538,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         log_string += ' loss scale: {:.1f} |'.format(loss_scale)
         if grad_norm is not None:
             log_string += ' grad norm: {:.3f} |'.format(grad_norm)
+        if num_zeros_in_grad is not None:
+            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
         if params_norm is not None:
             log_string += ' params norm: {:.3f} |'.format(params_norm)
         log_string += ' number of skipped iterations: {:3d} |'.format(

@@ -794,11 +564,11 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
     # Extra barrier is added to make sure
     # all ranks report the max time.
     torch.distributed.barrier()
-    timers('save checkpoint').start()
+    timers('save-checkpoint').start()
     save_checkpoint(iteration, model, optimizer, lr_scheduler)
     torch.distributed.barrier()
-    timers('save checkpoint').stop()
-    timers.log(['save checkpoint'])
+    timers('save-checkpoint').stop()
+    timers.log(['save-checkpoint'])


 def train(forward_step_func, model, optimizer, lr_scheduler,

@@ -811,7 +581,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     write_args_to_tensorboard()

     # Turn on training mode which enables dropout.
-    model.train()
+    for model_module in model:
+        model_module.train()

     # Tracking loss.
     total_loss_dict = {}

@@ -819,16 +590,17 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration

-    timers('interval time').start()
+    timers('interval-time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:
         update_num_microbatches(args.consumed_train_samples)
-        loss_dict, skipped_iter, grad_norm = train_step(forward_step_func,
-                                                        train_data_iterator,
-                                                        model,
-                                                        optimizer,
-                                                        lr_scheduler)
+        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
+            train_step(forward_step_func,
+                       train_data_iterator,
+                       model,
+                       optimizer,
+                       lr_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \

@@ -843,7 +615,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
                                           report_memory_flag, skipped_iter,
-                                          grad_norm, params_norm)
+                                          grad_norm, params_norm, num_zeros_in_grad)

         # Autoresume
         if args.adlr_autoresume and \

@@ -900,7 +672,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
     args = get_args()

     # Turn on evaluation mode which disables dropout.
-    model.eval()
+    for model_module in model:
+        model_module.eval()

     total_loss_dict = {}

@@ -912,37 +685,30 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                 print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                             args.eval_iters))
-            for _ in range(get_num_microbatches()):
-                if not mpu.is_pipeline_first_stage():
-                    input_tensor, _ = communicate(
-                        tensor_send_next=None,
-                        tensor_send_prev=None,
-                        recv_forward=True,
-                        recv_backward=False)
-                else:
-                    input_tensor = None
-
-                # Forward evaluation.
-                output_tensor = forward_step_func(data_iterator, model,
-                                                  input_tensor)
-
-                if mpu.is_pipeline_last_stage():
-                    _, loss_dict = output_tensor
-                    # Reduce across processes.
-                    for key in loss_dict:
-                        total_loss_dict[key] = total_loss_dict.get(
-                            key, torch.cuda.FloatTensor([0.0])) + \
-                            loss_dict[key]
-                else:
-                    communicate(tensor_send_next=output_tensor,
-                                tensor_send_prev=None,
-                                recv_forward=False,
-                                recv_backward=False)
+            if mpu.get_pipeline_model_parallel_world_size() > 1:
+                if args.virtual_pipeline_model_parallel_size is not None:
+                    forward_backward_func = forward_backward_pipelining_with_interleaving
+                else:
+                    forward_backward_func = forward_backward_pipelining_without_interleaving
+            else:
+                forward_backward_func = forward_backward_no_pipelining
+            loss_dicts = forward_backward_func(
+                forward_step_func, data_iterator, model, optimizer=None,
+                timers=None, forward_only=True)
+
+            if mpu.is_pipeline_last_stage(ignore_virtual=True):
+                # Reduce across processes.
+                for loss_dict in loss_dicts:
+                    for key in loss_dict:
+                        total_loss_dict[key] = total_loss_dict.get(
+                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                            * args.micro_batch_size \
                                            * get_num_microbatches()

     # Move model back to the train mode.
-    model.train()
+    for model_module in model:
+        model_module.train()

     for key in total_loss_dict:
         total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
...
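The refactored train_step() above now dispatches to one of the three schedule functions instead of calling a single pipelining routine. A tiny standalone sketch of that dispatch logic (the Args class and the string return values are illustrative stand-ins, not Megatron APIs):

# Hypothetical args holder; in Megatron this comes from get_args() and mpu.
class Args:
    virtual_pipeline_model_parallel_size = 2

def pick_schedule(pipeline_world_size, args):
    """Mirror of the schedule selection in the new train_step()."""
    if pipeline_world_size > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            return "forward_backward_pipelining_with_interleaving"
        return "forward_backward_pipelining_without_interleaving"
    return "forward_backward_no_pipelining"

print(pick_schedule(4, Args()))  # -> interleaved schedule
print(pick_schedule(1, Args()))  # -> no pipelining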
megatron/utils.py  View file @ ee7b19e7

@@ -18,6 +18,7 @@
 import sys

 import torch
+from torch.nn.parallel import DistributedDataParallel as torchDDP
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

@@ -26,20 +27,41 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
-from megatron.checkpointing import save_checkpoint
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


+def unwrap_model(model, module_instances=(torchDDP)):
+    return_list = True
+    if not isinstance(model, list):
+        model = [model]
+        return_list = False
+    unwrapped_model = []
+    for model_module in model:
+        while isinstance(model_module, module_instances):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    if not return_list:
+        return unwrapped_model[0]
+    return unwrapped_model
+
+
 def calc_params_l2_norm(model):
     """Calculate l2 norm of parameters """
+    args = get_args()
+    if not isinstance(model, list):
+        model = [model]
     # Remove duplicate params.
     params_data = []
-    for param in model.parameters():
-        is_not_shared = param_is_not_shared(param)
-        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
-        if is_not_shared and is_not_tp_duplicate:
-            params_data.append(param.data)
+    for model_ in model:
+        for param in model_.parameters():
+            is_not_shared = param_is_not_shared(param)
+            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+            if is_not_shared and is_not_tp_duplicate:
+                if args.bf16:
+                    params_data.append(param.data.float())
+                else:
+                    params_data.append(param.data)
     # Calculate norm
     dummy_overflow_buf = torch.cuda.IntTensor([0])
     norm, _ = multi_tensor_applier(

@@ -106,6 +128,8 @@ def print_params_min_max_norm(optimizer, iteration):
 def check_adlr_autoresume_termination(iteration, model,
                                       optimizer, lr_scheduler):
     """Check for autoresume signal and exit if it is received."""
+    from megatron.checkpointing import save_checkpoint
+
     args = get_args()
     autoresume = get_adlr_autoresume()
     # Add barrier to ensure consistnecy.
...
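The new unwrap_model() helper peels nested wrapper modules before the core model is accessed. A self-contained usage sketch (the Wrapper class below is an illustrative stand-in for torchDDP / LocalDDP / Float16Module):

import torch

class Wrapper(torch.nn.Module):
    """Stand-in for a DDP or mixed-precision wrapper exposing .module."""
    def __init__(self, module):
        super().__init__()
        self.module = module

def unwrap_model(model, module_instances=(Wrapper,)):
    # Same logic as the helper added in this commit.
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model

core = torch.nn.Linear(4, 4)
wrapped = [Wrapper(Wrapper(core))]          # doubly wrapped, as with DDP + fp16
assert unwrap_model(wrapped, (Wrapper,))[0] is core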
pretrain_bert.py  View file @ ee7b19e7

@@ -38,7 +38,7 @@ def model_provider():
     args = get_args()
     num_tokentypes = 2 if args.bert_binary_head else 0
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+    def model_provider_pipelined():
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
             model = BertModelFirstStage(

@@ -51,6 +51,17 @@ def model_provider():
         else:
             model = BertModelIntermediateStage(
                 num_tokentypes=num_tokentypes)
+        return model
+
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if args.virtual_pipeline_model_parallel_size is not None:
+            model = []
+            for i in range(args.virtual_pipeline_model_parallel_size):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                model.append(model_provider_pipelined())
+        else:
+            model = model_provider_pipelined()
     else:
         model = BertModel(
             num_tokentypes=num_tokentypes,

@@ -92,8 +103,8 @@ def forward_step(data_iterator, model, input_tensor):
     # Get the batch.
     timers('batch-generator').start()
-    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
-        = get_batch(data_iterator)
+    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
+        data_iterator)
     timers('batch-generator').stop()

     if not args.bert_binary_head:
...
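With the interleaved schedule, model_provider() now returns a list with one model chunk per virtual pipeline rank. A minimal standalone sketch of that loop (the functions below only mimic mpu.set_virtual_pipeline_model_parallel_rank() and the per-stage constructor; they are not the real Megatron APIs):

# Illustrative stand-ins only.
virtual_pipeline_model_parallel_size = 2
current_virtual_rank = None

def set_virtual_pipeline_model_parallel_rank(rank):
    global current_virtual_rank
    current_virtual_rank = rank

def model_provider_pipelined():
    # The real provider builds a BertModel*Stage based on the current rank.
    return f"model chunk for virtual rank {current_virtual_rank}"

model = []
for i in range(virtual_pipeline_model_parallel_size):
    set_virtual_pipeline_model_parallel_rank(i)
    model.append(model_provider_pipelined())

print(model)  # one model chunk per virtual pipeline rank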