OpenDAS / Megatron-LM / Commits

Commit 3b91262e
Authored Mar 02, 2021 by Vijay Korthikanti; committed by Jared Casper on Mar 23, 2021
Parent: 2f3a2d68

    pipeline code simplification

Showing 11 changed files with 268 additions and 631 deletions (+268, -631)
megatron/model/__init__.py         +2   -10
megatron/model/bert_model.py       +34  -84
megatron/model/classification.py   +29  -71
megatron/model/gpt_model.py        +25  -86
megatron/model/language_model.py   +41  -162
megatron/model/multiple_choice.py  +32  -73
megatron/model/transformer.py      +14  -5
megatron/schedules.py              +3   -1
megatron/training.py               +25  -7
pretrain_bert.py                   +40  -76
pretrain_gpt.py                    +23  -56
megatron/model/__init__.py

@@ -16,15 +16,7 @@
 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
 from .distributed import *
-from .bert_model import (BertModel,
-                         BertModelFirstStage,
-                         BertModelIntermediateStage,
-                         BertModelLastStage)
-from .gpt_model import (GPTModel,
-                        GPTModelFirstStage,
-                        GPTModelIntermediateStage,
-                        GPTModelLastStage)
+from .bert_model import BertModel
+from .gpt_model import GPTModel
 from .language_model import get_language_model
 from .module import Float16Module
megatron/model/bert_model.py

@@ -121,17 +121,23 @@ def post_language_model_processing(lm_output, pooled_output,
     return lm_loss, binary_logits


-class BertModelBase(MegatronModule):
+class BertModel(MegatronModule):
     """Bert Language model."""

-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModelBase, self).__init__()
+    def __init__(self,
+                 num_tokentypes=2,
+                 add_binary_head=True,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True):
+        super(BertModel, self).__init__()
         args = get_args()

         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process

         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(args.init_method_std,

@@ -142,10 +148,12 @@ class BertModelBase(MegatronModule):
             add_pooler=self.add_binary_head,
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         self.initialize_word_embeddings(init_method_normal)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.lm_head = BertLMHead(
                 self.word_embeddings_weight().size(0),
                 args.hidden_size, init_method, args.layernorm_epsilon,
                 parallel_output)

@@ -156,26 +164,29 @@ class BertModelBase(MegatronModule):
                 init_method)
             self._binary_head_key = 'binary_head'

+    def set_input_tensor(self, input_tensor):
+        self.language_model.set_input_tensor(input_tensor)
+
     def forward(self, bert_model_input, attention_mask,
                 tokentype_ids=None, lm_labels=None):

         extended_attention_mask = bert_extended_attention_mask(attention_mask)

-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = bert_model_input
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [bert_model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        input_ids = bert_model_input
+        position_ids = bert_position_ids(input_ids)
+
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids)
+
+        if self.post_process and self.add_binary_head:
             lm_output, pooled_output = lm_output
         else:
             pooled_output = None

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             return post_language_model_processing(lm_output, pooled_output,
                                                   self.lm_head, self.binary_head,
                                                   lm_labels,

@@ -194,15 +205,15 @@ class BertModelBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             state_dict_[self._lm_head_key] \
                 = self.lm_head.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        if self.post_process and self.add_binary_head:
             state_dict_[self._binary_head_key] \
                 = self.binary_head.state_dict(destination, prefix, keep_vars)
         # Save word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_

@@ -212,74 +223,13 @@ class BertModelBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.lm_head.load_state_dict(
                 state_dict[self._lm_head_key], strict=strict)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        if self.post_process and self.add_binary_head:
             self.binary_head.load_state_dict(
                 state_dict[self._binary_head_key], strict=strict)
         # Load word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
-
-
-class BertModel(BertModelBase):
-
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModel, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None, lm_labels=None):
-        return super(BertModel, self).forward(
-            input_ids, attention_mask,
-            tokentype_ids=tokentype_ids,
-            lm_labels=lm_labels)
-
-
-class BertModelFirstStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(BertModelFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(BertModelFirstStage, self).forward(
-            input_ids, attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class BertModelIntermediateStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(BertModelIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(BertModelIntermediateStage, self).forward(
-            hidden_state, attention_mask)
-
-
-class BertModelLastStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModelLastStage, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-
-    def forward(self, hidden_state, attention_mask, lm_labels=None):
-        return super(BertModelLastStage, self).forward(
-            hidden_state, attention_mask,
-            lm_labels=lm_labels)
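
The pattern above repeats in every model file in this commit: the per-stage subclasses disappear, and a single class decides at construction time whether it owns the embedding (pre_process) and the output head (post_process), while non-first stages receive their activations through set_input_tensor. A self-contained toy sketch of that pattern (illustrative only, not Megatron code; the real classes size themselves from get_args()):

import torch

class ToyStageModule(torch.nn.Module):
    """Toy illustration of the pre_process/post_process pattern:
    one class serves every pipeline stage."""

    def __init__(self, hidden, pre_process=True, post_process=True):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        # Only the first stage embeds raw inputs; only the last stage has a head.
        self.embed = torch.nn.Linear(8, hidden) if pre_process else None
        self.body = torch.nn.Linear(hidden, hidden)
        self.head = torch.nn.Linear(hidden, 2) if post_process else None

    def set_input_tensor(self, input_tensor):
        # Called by the pipeline schedule with the previous stage's output.
        self.input_tensor = input_tensor

    def forward(self, stage_input):
        x = self.embed(stage_input) if self.pre_process else self.input_tensor
        x = self.body(x)
        return self.head(x) if self.post_process else x

# First stage embeds raw inputs; a middle stage consumes received activations.
first = ToyStageModule(16, pre_process=True, post_process=False)
middle = ToyStageModule(16, pre_process=False, post_process=False)
middle.set_input_tensor(first(torch.randn(4, 8)))
print(middle(None).shape)  # stage_input is ignored when pre_process=False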
megatron/model/classification.py

@@ -28,13 +28,19 @@ from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule


-class ClassificationBase(MegatronModule):
+class Classification(MegatronModule):

-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationBase, self).__init__(share_word_embeddings=False)
+    def __init__(self,
+                 num_classes,
+                 num_tokentypes=2,
+                 pre_process=True,
+                 post_process=True):
+        super(Classification, self).__init__(share_word_embeddings=False)
         args = get_args()

         self.num_classes = num_classes
+        self.pre_process = pre_process
+        self.post_process = post_process
         init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(

@@ -43,31 +49,35 @@ class ClassificationBase(MegatronModule):
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers))
+                                                         args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         # Multi-choice head.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
             self.classification_head = get_linear_layer(args.hidden_size,
                                                         self.num_classes,
                                                         init_method)
             self._classification_head_key = 'classification_head'

+    def set_input_tensor(self, input_tensor):
+        self.language_model.set_input_tensor(input_tensor)
+
     def forward(self, model_input, attention_mask, tokentype_ids=None):

         extended_attention_mask = bert_extended_attention_mask(attention_mask)

-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = model_input
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage():
+        input_ids = model_input
+        position_ids = bert_position_ids(input_ids)
+
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids)
+
+        if self.post_process:
             _, pooled_output = lm_output
             classification_output = self.classification_dropout(pooled_output)
             classification_logits = self.classification_head(classification_output)

@@ -87,7 +97,7 @@ class ClassificationBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             state_dict_[self._classification_head_key] \
                 = self.classification_head.state_dict(
                     destination, prefix, keep_vars)

@@ -98,7 +108,7 @@ class ClassificationBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self._classification_head_key in state_dict:
                 self.classification_head.load_state_dict(
                     state_dict[self._classification_head_key], strict=strict)

@@ -106,55 +116,3 @@ class ClassificationBase(MegatronModule):
                 print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                 'initializing to random'.format(
                                     self._classification_head_key))
-
-
-class Classification(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(Classification, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(Classification, self).forward(
-            input_ids, attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class ClassificationFirstStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationFirstStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(ClassificationFirstStage, self).forward(
-            input_ids, attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class ClassificationIntermediateStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationIntermediateStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationIntermediateStage, self).forward(
-            hidden_state, attention_mask)
-
-
-class ClassificationLastStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationLastStage, self).__init__(
-            num_classes, num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationLastStage, self).forward(
-            hidden_state, attention_mask)
megatron/model/gpt_model.py

@@ -57,14 +57,20 @@ def post_language_model_processing(lm_output, labels, logit_weights,
     return loss


-class GPTModelBase(MegatronModule):
+class GPTModel(MegatronModule):
     """GPT-2 Language model."""

-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModelBase, self).__init__()
+    def __init__(self,
+                 num_tokentypes=0,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True):
+        super(GPTModel, self).__init__()
         args = get_args()

         self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

         self.language_model, self._language_model_key = get_language_model(

@@ -73,24 +79,27 @@ class GPTModelBase(MegatronModule):
             encoder_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers))
+                                                         args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         self.initialize_word_embeddings(init_method_normal)

-    def forward(self, gpt_model_input, attention_mask, labels=None,
+    def set_input_tensor(self, input_tensor):
+        self.language_model.set_input_tensor(input_tensor)
+
+    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):

-        kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
-        if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = gpt_model_input
-            args = [input_ids, position_ids, attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [gpt_model_input, attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            attention_mask,
+            layer_past=layer_past,
+            get_key_value=get_key_value)

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),

@@ -109,7 +118,7 @@ class GPTModelBase(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
         # Save word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_

@@ -118,79 +127,9 @@ class GPTModelBase(MegatronModule):
         """Customized load."""

         # Load word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
         if self._language_model_key in state_dict:
             state_dict = state_dict[self._language_model_key]
         self.language_model.load_state_dict(state_dict, strict=strict)
-
-
-class GPTModel(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModel, self).__init__(
-            num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output)
-
-    def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
-        return super(GPTModel, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            labels=labels,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            forward_method_parallel_output=forward_method_parallel_output)
-
-
-class GPTModelFirstStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0):
-        super(GPTModelFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
-        return super(GPTModelFirstStage, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class GPTModelIntermediateStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0):
-        super(GPTModelIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask,
-                layer_past=None, get_key_value=False):
-        return super(GPTModelIntermediateStage, self).forward(
-            hidden_state, attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class GPTModelLastStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModelLastStage, self).__init__(
-            num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output)
-
-    def forward(self, hidden_state, attention_mask, labels=None,
-                layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
-        return super(GPTModelLastStage, self).forward(
-            hidden_state, attention_mask,
-            labels=labels,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            forward_method_parallel_output=forward_method_parallel_output)
megatron/model/language_model.py

@@ -46,7 +46,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(num_tokentypes, add_pooler,
                        encoder_attn_mask_type, init_method=None,
                        scaled_init_method=None, add_decoder=False,
-                       decoder_attn_mask_type=AttnMaskType.causal):
+                       decoder_attn_mask_type=AttnMaskType.causal,
+                       pre_process=True, post_process=True):
     """Build language model and return along with the key to save."""
     args = get_args()

@@ -58,26 +59,17 @@ def get_language_model(num_tokentypes, add_pooler,
                                                       args.num_layers)

     # Language model.
-    args = [init_method, scaled_init_method, encoder_attn_mask_type]
-    kwargs = {}
-    cls = None
-    if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModel
-        kwargs['num_tokentypes'] = num_tokentypes
-        kwargs['add_decoder'] = add_decoder
-        kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
-        kwargs['add_pooler'] = add_pooler
-    elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModelFirstStage
-        kwargs['num_tokentypes'] = num_tokentypes
-    elif not mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModelLastStage
-        kwargs['add_pooler'] = add_pooler
-    else:
-        cls = TransformerLanguageModelIntermediateStage
-
-    # Language model.
-    language_model = cls(*args, **kwargs)
+    language_model = TransformerLanguageModel(
+        init_method,
+        scaled_init_method,
+        encoder_attn_mask_type,
+        num_tokentypes=num_tokentypes,
+        add_decoder=add_decoder,
+        decoder_attn_mask_type=decoder_attn_mask_type,
+        add_pooler=add_pooler,
+        pre_process=pre_process,
+        post_process=post_process)

     # key used for checkpoints.
     language_model_key = 'language_model'

@@ -263,7 +255,7 @@ class Embedding(MegatronModule):
                       'checkpoint but could not find it', flush=True)


-class TransformerLanguageModelBase(MegatronModule):
+class TransformerLanguageModel(MegatronModule):
     """Transformer language model.

     Arguments:

@@ -283,10 +275,14 @@ class TransformerLanguageModelBase(MegatronModule):
                  num_tokentypes=0,
                  add_decoder=False,
                  decoder_attn_mask_type=AttnMaskType.causal,
-                 add_pooler=False):
-        super(TransformerLanguageModelBase, self).__init__()
+                 add_pooler=False,
+                 pre_process=True,
+                 post_process=True):
+        super(TransformerLanguageModel, self).__init__()
         args = get_args()

+        self.pre_process = pre_process
+        self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method

@@ -296,7 +292,7 @@ class TransformerLanguageModelBase(MegatronModule):
         self.add_pooler = add_pooler

         # Embeddings.
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             self.embedding = Embedding(self.hidden_size,
                                        args.padded_vocab_size,
                                        args.max_position_embeddings,

@@ -309,7 +305,10 @@ class TransformerLanguageModelBase(MegatronModule):
         self.encoder = ParallelTransformer(
             self.init_method,
             output_layer_init_method,
-            self_attn_mask_type=self.encoder_attn_mask_type)
+            self_attn_mask_type=self.encoder_attn_mask_type,
+            pre_process=self.pre_process,
+            post_process=self.post_process)
         self._encoder_key = 'encoder'

         # Decoder

@@ -323,26 +322,28 @@ class TransformerLanguageModelBase(MegatronModule):
                 self_attn_mask_type=self.decoder_attn_mask_type)
             self._decoder_key = 'decoder'

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # Pooler.
             if self.add_pooler:
                 self.pooler = Pooler(self.hidden_size, self.init_method)
                 self._pooler_key = 'pooler'

-    def forward(self, enc_language_model_input, enc_attn_mask,
-                dec_language_model_input=None, dec_attn_mask=None,
+    def set_input_tensor(self, input_tensor):
+        self.encoder.set_input_tensor(input_tensor)
+
+    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
                 layer_past=None, get_key_value=False,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

         # Embeddings.
-        if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = enc_language_model_input
-            embedding_output = self.embedding(input_ids, position_ids,
+        if self.pre_process:
+            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
                                               tokentype_ids=tokentype_ids)
             encoder_input = embedding_output
         else:
-            encoder_input = enc_language_model_input
+            encoder_input = None

         # encoder.
         if enc_hidden_states is None:

@@ -353,7 +354,7 @@ class TransformerLanguageModelBase(MegatronModule):
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self.add_pooler:
                 pooled_output = self.pooler(encoder_output,
                                             pooling_sequence_index)

@@ -362,13 +363,12 @@ class TransformerLanguageModelBase(MegatronModule):
         # output. For example, it is helpful to compute
         # similarity between two sequences by average pooling
         if not self.add_decoder or output_enc_hidden:
-            if self.add_pooler and mpu.is_pipeline_last_stage():
+            if self.add_pooler and self.post_process:
                 return encoder_output, pooled_output
             else:
                 return encoder_output

         # Decoder Embedding
-        (dec_input_ids, dec_position_ids) = dec_language_model_input
         dec_embedding_output = self.embedding(dec_input_ids,
                                               dec_position_ids)
         # decoder

@@ -379,7 +379,7 @@ class TransformerLanguageModelBase(MegatronModule):
                                       encoder_output=encoder_output,
                                       enc_dec_attn_mask=enc_dec_attn_mask)

-        if self.add_pooler and mpu.is_pipeline_last_stage():
+        if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
         else:
             return decoder_output, encoder_output

@@ -389,14 +389,14 @@ class TransformerLanguageModelBase(MegatronModule):
         """For easy load."""

         state_dict_ = {}
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             state_dict_[self._embedding_key] \
                 = self.embedding.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
         state_dict_[self._encoder_key] \
             = self.encoder.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self.add_pooler:
                 state_dict_[self._pooler_key] \
                     = self.pooler.state_dict_for_save_checkpoint(

@@ -412,7 +412,7 @@ class TransformerLanguageModelBase(MegatronModule):
         """Customized load."""

         # Embedding.
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             if self._embedding_key in state_dict:
                 state_dict_ = state_dict[self._embedding_key]
             else:

@@ -448,7 +448,7 @@ class TransformerLanguageModelBase(MegatronModule):
         self.encoder.load_state_dict(state_dict_, strict=strict)

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # pooler
             if self.add_pooler:
                 assert 'pooler' in state_dict, \

@@ -461,124 +461,3 @@ class TransformerLanguageModelBase(MegatronModule):
                     'could not find data for pooler in the checkpoint'
             self.decoder.load_state_dict(state_dict[self._decoder_key],
                                          strict=strict)
-
-
-class TransformerLanguageModel(TransformerLanguageModelBase):
-    """Transformer language model (see TransformerLanguageModelBase
-    for description of arguments).
-    """
-
-    def __init__(self, init_method, output_layer_init_method,
-                 encoder_attn_mask_type, num_tokentypes=0,
-                 decoder_attn_mask_type=AttnMaskType.causal,
-                 add_decoder=False, add_pooler=False):
-        super(TransformerLanguageModel, self).__init__(
-            init_method, output_layer_init_method,
-            encoder_attn_mask_type,
-            num_tokentypes=num_tokentypes,
-            add_decoder=add_decoder,
-            decoder_attn_mask_type=decoder_attn_mask_type,
-            add_pooler=add_pooler)
-
-    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
-                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None,
-                layer_past=None, get_key_value=False,
-                pooling_sequence_index=0,
-                enc_hidden_states=None, output_enc_hidden=False):
-        return super(TransformerLanguageModel, self).forward(
-            (enc_input_ids, enc_position_ids),
-            enc_attn_mask,
-            dec_language_model_input=(dec_input_ids, dec_position_ids),
-            dec_attn_mask=dec_attn_mask,
-            enc_dec_attn_mask=enc_dec_attn_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index,
-            enc_hidden_states=enc_hidden_states,
-            output_enc_hidden=output_enc_hidden)
-
-
-class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
-    """Transformer language model, first stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-
-    def __init__(self, init_method, output_layer_init_method,
-                 encoder_attn_mask_type, num_tokentypes=0):
-        super(TransformerLanguageModelFirstStage, self).__init__(
-            init_method, output_layer_init_method,
-            encoder_attn_mask_type,
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
-        return super(TransformerLanguageModelFirstStage, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
-    """Transformer language model, intermediate stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-
-    def __init__(self, init_method, output_layer_init_method,
-                 encoder_attn_mask_type):
-        super(TransformerLanguageModelIntermediateStage, self).__init__(
-            init_method, output_layer_init_method,
-            encoder_attn_mask_type)
-
-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False):
-        return super(TransformerLanguageModelIntermediateStage, self).forward(
-            hidden_states, attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
-    """Transformer language model, final stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-
-    def __init__(self, init_method, output_layer_init_method,
-                 encoder_attn_mask_type, add_pooler=False):
-        super(TransformerLanguageModelLastStage, self).__init__(
-            init_method, output_layer_init_method,
-            encoder_attn_mask_type,
-            add_pooler=add_pooler)
-
-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
-        return super(TransformerLanguageModelLastStage, self).forward(
-            hidden_states, attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index)
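
With the stage-specific language-model classes removed, the activation hand-off is wired as a delegation chain rather than through constructor choices. The chain, as introduced in the diffs above (shown here as a summary comment, not new code):

# set_input_tensor delegation chain after this commit:
#   BertModel / GPTModel .set_input_tensor(t)
#       -> TransformerLanguageModel.set_input_tensor(t)   # via self.language_model
#           -> ParallelTransformer.set_input_tensor(t)    # via self.encoder
#
# ParallelTransformer stores t in self.input_tensor and, when pre_process is
# False, uses it in forward() in place of the (absent) embedding output.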
megatron/model/multiple_choice.py

@@ -28,13 +28,18 @@ from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule


-class MultipleChoiceBase(MegatronModule):
+class MultipleChoice(MegatronModule):

-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceBase, self).__init__(share_word_embeddings=False)
+    def __init__(self,
+                 num_tokentypes=2,
+                 pre_process=True,
+                 post_process=True):
+        super(MultipleChoice, self).__init__(share_word_embeddings=False)
         args = get_args()

         init_method = init_method_normal(args.init_method_std)
+        self.pre_process = pre_process
+        self.post_process = post_process

         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,

@@ -42,15 +47,20 @@ class MultipleChoiceBase(MegatronModule):
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers))
+                                                         args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         # Multi-choice head.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
             self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                      init_method)
             self._multichoice_head_key = 'multichoice_head'

+    def set_input_tensor(self, input_tensor):
+        self.language_model.set_input_tensor(input_tensor)
+
     def forward(self, model_input, attention_mask, tokentype_ids=None):

         # [batch, choices, sequence] --> [batch * choices, sequence] -->

@@ -64,22 +74,21 @@ class MultipleChoiceBase(MegatronModule):
         attention_mask = attention_mask.view(-1, attention_mask.size(-1))
         extended_attention_mask = bert_extended_attention_mask(attention_mask)

-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = model_input
-            # Do the same as attention_mask for input_ids, tokentype_ids
-            assert len(input_ids.shape) == 3
-            assert len(tokentype_ids.shape) == 3
-            input_ids = input_ids.view(-1, input_ids.size(-1))
-            tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage():
+        input_ids = model_input
+        # Do the same as attention_mask for input_ids, tokentype_ids
+        assert len(input_ids.shape) == 3
+        assert len(tokentype_ids.shape) == 3
+        input_ids = input_ids.view(-1, input_ids.size(-1))
+        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
+        position_ids = bert_position_ids(input_ids)
+
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids)
+        if self.post_process:
             _, pooled_output = lm_output
             multichoice_output = self.multichoice_dropout(pooled_output)
             multichoice_logits = self.multichoice_head(multichoice_output)

@@ -99,7 +108,7 @@ class MultipleChoiceBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             state_dict_[self._multichoice_head_key] \
                 = self.multichoice_head.state_dict(
                     destination, prefix, keep_vars)

@@ -110,7 +119,7 @@ class MultipleChoiceBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self._multichoice_head_key in state_dict:
                 self.multichoice_head.load_state_dict(
                     state_dict[self._multichoice_head_key], strict=strict)

@@ -119,53 +128,3 @@ class MultipleChoiceBase(MegatronModule):
                                 'initializing to random'.format(
                                     self._multichoice_head_key))
-
-
-class MultipleChoice(MultipleChoiceBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoice, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(MultipleChoice, self).forward(
-            input_ids, attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class MultipleChoiceFirstStage(MultipleChoiceBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(MultipleChoiceFirstStage, self).forward(
-            input_ids, attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class MultipleChoiceIntermediateStage(MultipleChoiceBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(MultipleChoiceIntermediateStage, self).forward(
-            hidden_state, attention_mask)
-
-
-class MultipleChoiceLastStage(MultipleChoiceBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceLastStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(MultipleChoiceLastStage, self).forward(
-            hidden_state, attention_mask)
megatron/model/transformer.py

@@ -532,12 +532,16 @@ class ParallelTransformer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 self_attn_mask_type=AttnMaskType.padding,
+                 pre_process=True, post_process=True):
         super(ParallelTransformer, self).__init__()
         args = get_args()

         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.input_tensor = None

         # Store activation checkpoiting flag.
         self.checkpoint_activations = args.checkpoint_activations

@@ -580,7 +584,7 @@ class ParallelTransformer(MegatronModule):
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
                 args.hidden_size,

@@ -615,6 +619,9 @@ class ParallelTransformer(MegatronModule):
         return hidden_states

+    def set_input_tensor(self, input_tensor):
+        self.input_tensor = input_tensor
+
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):

@@ -628,7 +635,7 @@ class ParallelTransformer(MegatronModule):
             'get_key_value does not work with ' \
             'activation checkpointing'

-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
             # If the input flag for fp32 residual connection is set, convert for float.
             if self.fp32_residual_connection:

@@ -636,6 +643,8 @@ class ParallelTransformer(MegatronModule):
             # Otherwise, leave it as is.
             else:
                 hidden_states = hidden_states.transpose(0, 1).contiguous()
+        else:
+            hidden_states = self.input_tensor

         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()

@@ -664,7 +673,7 @@ class ParallelTransformer(MegatronModule):
                     presents.append(present)

         # Final layer norm.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
megatron/schedules.py

@@ -34,8 +34,10 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
     timers = get_timers()

     timers('forward-compute').start()
-    output_tensor = forward_step_func(data_iterator, model, input_tensor)
+    model.module.module.set_input_tensor(input_tensor)
+    output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
+        output_tensor = loss_func(output_tensor)
         loss, loss_reduced = output_tensor
         output_tensor = loss / get_num_microbatches()
         losses_reduced.append(loss_reduced)
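
The two-line change above redefines the forward-step contract: the schedule now pushes the received activations into the model via model.module.module.set_input_tensor(...) (two .module hops to reach the underlying model through its wrappers), and forward_step_func(data_iterator, model) returns the stage output together with a loss callable that is applied only on the last pipeline stage. A minimal sketch of a conforming pair, assuming a hypothetical batch dict with 'tokens' and 'loss_mask' entries (not the real pretraining code, which follows below):

from functools import partial
import torch

def loss_func(loss_mask, output_tensor):
    # Toy loss matching the new contract: return (loss, dict for logging).
    loss_mask = loss_mask.view(-1).float()
    loss = (output_tensor.float().view(-1) * loss_mask).sum() / loss_mask.sum()
    return loss, {'lm loss': loss.detach()}

def forward_step_func(data_iterator, model):
    # New contract from this commit: no input_tensor argument; return the
    # stage output plus a callable that turns it into (loss, reduced-loss dict).
    batch = next(data_iterator)
    output_tensor = model(batch['tokens'])
    return output_tensor, partial(loss_func, batch['loss_mask'])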
megatron/training.py

@@ -196,7 +196,25 @@ def get_model(model_provider_func):
     args = get_args()

     # Build model on cpu.
-    model = model_provider_func()
+    pre_process = mpu.is_pipeline_first_stage()
+    post_process = mpu.is_pipeline_last_stage()
+    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
+       args.virtual_pipeline_model_parallel_size is not None:
+        model = []
+        for i in range(args.virtual_pipeline_model_parallel_size):
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            m = model_provider_func(
+                pre_process=pre_process,
+                post_process=post_process)
+            model.append(m)
+    else:
+        model = model_provider_func(
+            pre_process=pre_process,
+            post_process=post_process)
+
+    if not isinstance(model, list):
+        model = [model]
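
get_model() now derives pre_process/post_process from the pipeline rank and forwards them to the provider, so pretraining scripts only have to accept the two keyword arguments; the virtual-pipeline branch builds one model chunk per virtual rank, and the result is always normalized to a list. A sketch of the provider signature this expects (the real providers appear in pretrain_bert.py and pretrain_gpt.py below):

from megatron.model import GPTModel

# Sketch only: a provider compatible with the new get_model(), mirroring the
# GPT provider shown later in this commit.
def model_provider(pre_process=True, post_process=True):
    return GPTModel(num_tokentypes=0, parallel_output=True,
                    pre_process=pre_process, post_process=post_process)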
pretrain_bert.py

@@ -17,56 +17,30 @@
 import torch
 import torch.nn.functional as F
+from functools import partial

 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import (BertModel,
-                            BertModelFirstStage,
-                            BertModelIntermediateStage,
-                            BertModelLastStage)
+from megatron.model import BertModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group


-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0('building BERT model ...')

     args = get_args()
     num_tokentypes = 2 if args.bert_binary_head else 0

-    def model_provider_pipelined():
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = BertModelFirstStage(
-                num_tokentypes=num_tokentypes)
-        elif mpu.is_pipeline_last_stage():
-            model = BertModelLastStage(
-                num_tokentypes=num_tokentypes,
-                add_binary_head=args.bert_binary_head,
-                parallel_output=True)
-        else:
-            model = BertModelIntermediateStage(
-                num_tokentypes=num_tokentypes)
-        return model
-
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        if args.virtual_pipeline_model_parallel_size is not None:
-            model = []
-            for i in range(args.virtual_pipeline_model_parallel_size):
-                mpu.set_virtual_pipeline_model_parallel_rank(i)
-                model.append(model_provider_pipelined())
-        else:
-            model = model_provider_pipelined()
-    else:
-        model = BertModel(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=args.bert_binary_head,
-            parallel_output=True)
+    model = BertModel(
+        num_tokentypes=num_tokentypes,
+        add_binary_head=args.bert_binary_head,
+        parallel_output=True,
+        pre_process=pre_process,
+        post_process=post_process)

     return model

@@ -96,36 +70,7 @@ def get_batch(data_iterator):
     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


-def forward_step(data_iterator, model, input_tensor):
-    """Forward step."""
-    args = get_args()
-    timers = get_timers()
-
-    # Get the batch.
-    timers('batch-generator').start()
-    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
-        data_iterator)
-    timers('batch-generator').stop()
-
-    if not args.bert_binary_head:
-        types = None
-
-    # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output_tensor = model(tokens, padding_mask, tokentype_ids=types,
-                                  lm_labels=lm_labels)
-        else:
-            output_tensor = model(tokens, padding_mask, tokentype_ids=types)
-    elif mpu.is_pipeline_last_stage():
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, padding_mask)
-
-    if mpu.is_pipeline_last_stage():
+def loss_func(loss_mask, sentence_order, output_tensor):
     lm_loss_, sop_logits = output_tensor
     lm_loss_ = lm_loss_.float()

@@ -150,7 +95,26 @@ def forward_step(data_iterator, model, input_tensor):
             [lm_loss])
         return loss, {'lm loss': averaged_losses[0]}

-    return output_tensor
+
+def forward_step(data_iterator, model):
+    """Forward step."""
+    args = get_args()
+    timers = get_timers()
+
+    # Get the batch.
+    timers('batch-generator').start()
+    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
+        data_iterator)
+    timers('batch-generator').stop()
+
+    if not args.bert_binary_head:
+        types = None
+
+    # Forward pass through the model.
+    output_tensor = model(tokens, padding_mask, tokentype_ids=types,
+                          lm_labels=lm_labels)
+
+    return output_tensor, partial(loss_func, loss_mask, sentence_order)


 def train_valid_test_datasets_provider(train_val_test_num_samples):
pretrain_gpt.py

@@ -16,50 +16,28 @@
 """Pretrain GPT"""

 import torch
+from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import (GPTModel,
-                            GPTModelFirstStage,
-                            GPTModelIntermediateStage,
-                            GPTModelLastStage)
+from megatron.model import GPTModel
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group


-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0('building GPT model ...')
-
-    def model_provider_pipelined():
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = GPTModelFirstStage(num_tokentypes=0)
-        elif mpu.is_pipeline_last_stage():
-            model = GPTModelLastStage(
-                num_tokentypes=0, parallel_output=True)
-        else:
-            model = GPTModelIntermediateStage(
-                num_tokentypes=0)
-        return model
-
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        if args.virtual_pipeline_model_parallel_size is not None:
-            model = []
-            for i in range(args.virtual_pipeline_model_parallel_size):
-                mpu.set_virtual_pipeline_model_parallel_rank(i)
-                model.append(model_provider_pipelined())
-        else:
-            model = model_provider_pipelined()
-    else:
-        model = GPTModel(num_tokentypes=0, parallel_output=True)
+    model = GPTModel(num_tokentypes=0, parallel_output=True,
+                     pre_process=pre_process, post_process=post_process)

     return model

@@ -94,8 +72,18 @@ def get_batch(data_iterator):
     return tokens, labels, loss_mask, attention_mask, position_ids


+def loss_func(loss_mask, output_tensor):
+    losses = output_tensor.float()
+    loss_mask = loss_mask.view(-1).float()
+    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+
+    return loss, {'lm loss': averaged_loss[0]}
+
+
-def forward_step(data_iterator, model, input_tensor):
+def forward_step(data_iterator, model):
     """Forward step."""
     args = get_args()
     timers = get_timers()

@@ -106,31 +94,10 @@ def forward_step(data_iterator, model, input_tensor):
         data_iterator)
     timers('batch-generator').stop()

     # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output_tensor = model(tokens, position_ids, attention_mask,
-                                  labels=labels)
-        else:
-            output_tensor = model(tokens, position_ids, attention_mask)
-    elif mpu.is_pipeline_last_stage():
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask, labels=labels)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask)
-
-    if mpu.is_pipeline_last_stage():
-        losses = output_tensor.float()
-        loss_mask = loss_mask.view(-1).float()
-        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
-
-        # Reduce loss for logging.
-        averaged_loss = average_losses_across_data_parallel_group([loss])
-
-        return loss, {'lm loss': averaged_loss[0]}
-
-    return output_tensor
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          labels=labels)
+
+    return output_tensor, partial(loss_func, loss_mask)


 def train_valid_test_datasets_provider(train_val_test_num_samples):