OpenDAS / Megatron-LM · Commits

Commit 3fc035d7 · authored Apr 03, 2021 by Mohammad Shoeybi

Merge branch 'pipeline_refactor' into 'main'

Pipeline refactor

See merge request ADLR/megatron-lm!254

Parents: f2d64c00, e270f68a

Showing 20 changed files with 479 additions and 837 deletions (+479, -837)
megatron/model/__init__.py           +3    -11
megatron/model/bert_model.py         +35   -84
megatron/model/classification.py     +30   -71
megatron/model/gpt_model.py          +26   -86
megatron/model/language_model.py     +42   -162
megatron/model/multiple_choice.py    +33   -74
megatron/model/transformer.py        +24   -6
megatron/schedules.py                +19   -1
megatron/text_generation_utils.py    +18   -41
megatron/training.py                 +26   -8
pretrain_bert.py                     +40   -76
pretrain_gpt.py                      +23   -56
tasks/eval_utils.py                  +74   -51
tasks/finetune_utils.py              +33   -25
tasks/glue/finetune.py               +4    -15
tasks/main.py                        +5    -0
tasks/race/data.py                   +2    -0
tasks/race/finetune.py               +5    -12
tasks/zeroshot_gpt/evaluate.py       +24   -39
tools/generate_samples_gpt.py        +13   -19
megatron/model/__init__.py

@@ -15,16 +15,8 @@
 from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm

-from .distributed import *
-from .bert_model import (BertModel,
-                         BertModelFirstStage,
-                         BertModelIntermediateStage,
-                         BertModelLastStage)
-from .gpt_model import (GPTModel,
-                        GPTModelFirstStage,
-                        GPTModelIntermediateStage,
-                        GPTModelLastStage)
+from .distributed import DistributedDataParallel
+from .bert_model import BertModel
+from .gpt_model import GPTModel
 from .language_model import get_language_model
 from .module import Float16Module
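
The net effect of this refactor is that callers no longer pick a per-pipeline-stage class (FirstStage/IntermediateStage/LastStage); they construct one model class and tell it whether the current rank owns the input embedding (pre_process) and/or the output head (post_process). As an illustration only — this mirrors the pretrain_gpt.py changes included in this merge request, whose hunks are not reproduced on this page — a model provider under the new API might look like the following minimal sketch:

from megatron import mpu
from megatron.model import GPTModel

def model_provider():
    # The first pipeline stage owns the input embedding, the last stage owns
    # the output head; every stage now uses the same model class.
    return GPTModel(num_tokentypes=0,
                    parallel_output=True,
                    pre_process=mpu.is_pipeline_first_stage(),
                    post_process=mpu.is_pipeline_last_stage())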
megatron/model/bert_model.py

@@ -121,17 +121,23 @@ def post_language_model_processing(lm_output, pooled_output,
     return lm_loss, binary_logits


-class BertModelBase(MegatronModule):
+class BertModel(MegatronModule):
     """Bert Language model."""

-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModelBase, self).__init__()
+    def __init__(self,
+                 num_tokentypes=2,
+                 add_binary_head=True,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True):
+        super(BertModel, self).__init__()
         args = get_args()

         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process

         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
@@ -142,10 +148,12 @@ class BertModelBase(MegatronModule):
             add_pooler=self.add_binary_head,
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         self.initialize_word_embeddings(init_method_normal)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.lm_head = BertLMHead(
                 self.word_embeddings_weight().size(0),
                 args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
@@ -156,26 +164,30 @@ class BertModelBase(MegatronModule):
                     init_method)
                 self._binary_head_key = 'binary_head'

+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.language_model.set_input_tensor(input_tensor)
+
     def forward(self, bert_model_input, attention_mask,
                 tokentype_ids=None, lm_labels=None):

         extended_attention_mask = bert_extended_attention_mask(attention_mask)
+        input_ids = bert_model_input
+        position_ids = bert_position_ids(input_ids)

-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = bert_model_input
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [bert_model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids
+        )
+
+        if self.post_process and self.add_binary_head:
             lm_output, pooled_output = lm_output
         else:
             pooled_output = None

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             return post_language_model_processing(lm_output, pooled_output,
                                                   self.lm_head, self.binary_head,
                                                   lm_labels,
@@ -194,15 +206,15 @@ class BertModelBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             state_dict_[self._lm_head_key] \
                 = self.lm_head.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        if self.post_process and self.add_binary_head:
             state_dict_[self._binary_head_key] \
                 = self.binary_head.state_dict(destination, prefix, keep_vars)
         # Save word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_
@@ -212,74 +224,13 @@ class BertModelBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.lm_head.load_state_dict(
                 state_dict[self._lm_head_key], strict=strict)
-        if mpu.is_pipeline_last_stage() and self.add_binary_head:
+        if self.post_process and self.add_binary_head:
             self.binary_head.load_state_dict(
                 state_dict[self._binary_head_key], strict=strict)
         # Load word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
-
-
-class BertModel(BertModelBase):
-
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModel, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None, lm_labels=None):
-        return super(BertModel, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            lm_labels=lm_labels)
-
-
-class BertModelFirstStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(BertModelFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
-        return super(BertModelFirstStage, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class BertModelIntermediateStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(BertModelIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(BertModelIntermediateStage, self).forward(
-            hidden_state,
-            attention_mask)
-
-
-class BertModelLastStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModelLastStage, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-
-    def forward(self, hidden_state, attention_mask,
-                lm_labels=None):
-        return super(BertModelLastStage, self).forward(
-            hidden_state,
-            attention_mask,
-            lm_labels=lm_labels)
megatron/model/classification.py

@@ -28,13 +28,19 @@ from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule


-class ClassificationBase(MegatronModule):
+class Classification(MegatronModule):

-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationBase, self).__init__(share_word_embeddings=False)
+    def __init__(self,
+                 num_classes,
+                 num_tokentypes=2,
+                 pre_process=True,
+                 post_process=True):
+        super(Classification, self).__init__(share_word_embeddings=False)
         args = get_args()

         self.num_classes = num_classes
+        self.pre_process = pre_process
+        self.post_process = post_process
         init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
@@ -43,31 +49,36 @@ class ClassificationBase(MegatronModule):
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers))
+                                                         args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         # Multi-choice head.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
             self.classification_head = get_linear_layer(args.hidden_size,
                                                         self.num_classes,
                                                         init_method)
             self._classification_head_key = 'classification_head'

+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.language_model.set_input_tensor(input_tensor)
+
     def forward(self, model_input, attention_mask, tokentype_ids=None):

         extended_attention_mask = bert_extended_attention_mask(attention_mask)
+        input_ids = model_input
+        position_ids = bert_position_ids(input_ids)

-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = model_input
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage():
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids
+        )
+
+        if self.post_process:
             _, pooled_output = lm_output
             classification_output = self.classification_dropout(pooled_output)
             classification_logits = self.classification_head(classification_output)
@@ -87,7 +98,7 @@ class ClassificationBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             state_dict_[self._classification_head_key] \
                 = self.classification_head.state_dict(
                     destination, prefix, keep_vars)
@@ -98,7 +109,7 @@ class ClassificationBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self._classification_head_key in state_dict:
                 self.classification_head.load_state_dict(
                     state_dict[self._classification_head_key], strict=strict)
@@ -106,55 +117,3 @@ class ClassificationBase(MegatronModule):
                 print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                 'initializing to random'.format(
                                     self._classification_head_key))
-
-
-class Classification(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(Classification, self).__init__(
-            num_classes,
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(Classification, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class ClassificationFirstStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationFirstStage, self).__init__(
-            num_classes,
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(ClassificationFirstStage, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class ClassificationIntermediateStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationIntermediateStage, self).__init__(
-            num_classes,
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationIntermediateStage, self).forward(
-            hidden_state,
-            attention_mask)
-
-
-class ClassificationLastStage(ClassificationBase):
-
-    def __init__(self, num_classes, num_tokentypes=2):
-        super(ClassificationLastStage, self).__init__(
-            num_classes,
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(ClassificationLastStage, self).forward(
-            hidden_state,
-            attention_mask)
megatron/model/gpt_model.py

@@ -57,14 +57,20 @@ def post_language_model_processing(lm_output, labels, logit_weights,
     return loss


-class GPTModelBase(MegatronModule):
+class GPTModel(MegatronModule):
     """GPT-2 Language model."""

-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModelBase, self).__init__()
+    def __init__(self,
+                 num_tokentypes=0,
+                 parallel_output=True,
+                 pre_process=True,
+                 post_process=True):
+        super(GPTModel, self).__init__()
         args = get_args()

         self.parallel_output = parallel_output
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

         self.language_model, self._language_model_key = get_language_model(
@@ -73,24 +79,28 @@ class GPTModelBase(MegatronModule):
             encoder_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers))
+                                                         args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         self.initialize_word_embeddings(init_method_normal)

-    def forward(self, gpt_model_input, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.language_model.set_input_tensor(input_tensor)
+
+    def forward(self, input_ids, position_ids, attention_mask, labels=None,
+                tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):

-        kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
-        if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = gpt_model_input
-            args = [input_ids, position_ids, attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [gpt_model_input, attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            attention_mask,
+            layer_past=layer_past,
+            get_key_value=get_key_value)

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),
@@ -109,7 +119,7 @@ class GPTModelBase(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
         # Save word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(destination, prefix, keep_vars)
         return state_dict_
@@ -118,79 +128,9 @@ class GPTModelBase(MegatronModule):
         """Customized load."""

         # Load word_embeddings.
-        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+        if self.post_process and not self.pre_process:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
         if self._language_model_key in state_dict:
             state_dict = state_dict[self._language_model_key]
         self.language_model.load_state_dict(state_dict, strict=strict)
-
-
-class GPTModel(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModel, self).__init__(
-            num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output)
-
-    def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
-        return super(GPTModel, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            labels=labels,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            forward_method_parallel_output=forward_method_parallel_output)
-
-
-class GPTModelFirstStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0):
-        super(GPTModelFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
-        return super(GPTModelFirstStage, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class GPTModelIntermediateStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0):
-        super(GPTModelIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask,
-                layer_past=None, get_key_value=False):
-        return super(GPTModelIntermediateStage, self).forward(
-            hidden_state,
-            attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class GPTModelLastStage(GPTModelBase):
-
-    def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPTModelLastStage, self).__init__(
-            num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output)
-
-    def forward(self, hidden_state, attention_mask, labels=None,
-                layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
-        return super(GPTModelLastStage, self).forward(
-            hidden_state,
-            attention_mask,
-            labels=labels,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            forward_method_parallel_output=forward_method_parallel_output)
megatron/model/language_model.py

@@ -46,7 +46,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(num_tokentypes, add_pooler,
                        encoder_attn_mask_type, init_method=None,
                        scaled_init_method=None, add_decoder=False,
-                       decoder_attn_mask_type=AttnMaskType.causal):
+                       decoder_attn_mask_type=AttnMaskType.causal,
+                       pre_process=True, post_process=True):
     """Build language model and return along with the key to save."""
     args = get_args()
@@ -58,26 +59,17 @@ def get_language_model(num_tokentypes, add_pooler,
                                                   args.num_layers)

     # Language model.
-    args = [init_method, scaled_init_method, encoder_attn_mask_type]
-    kwargs = {}
-    cls = None
-    if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModel
-        kwargs['num_tokentypes'] = num_tokentypes
-        kwargs['add_decoder'] = add_decoder
-        kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
-        kwargs['add_pooler'] = add_pooler
-    elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModelFirstStage
-        kwargs['num_tokentypes'] = num_tokentypes
-    elif not mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
-        cls = TransformerLanguageModelLastStage
-        kwargs['add_pooler'] = add_pooler
-    else:
-        cls = TransformerLanguageModelIntermediateStage
-    language_model = cls(*args, **kwargs)
+    language_model = TransformerLanguageModel(
+        init_method,
+        scaled_init_method,
+        encoder_attn_mask_type,
+        num_tokentypes=num_tokentypes,
+        add_decoder=add_decoder,
+        decoder_attn_mask_type=decoder_attn_mask_type,
+        add_pooler=add_pooler,
+        pre_process=pre_process,
+        post_process=post_process
+    )

     # key used for checkpoints.
     language_model_key = 'language_model'
@@ -263,7 +255,7 @@ class Embedding(MegatronModule):
               'checkpoint but could not find it', flush=True)


-class TransformerLanguageModelBase(MegatronModule):
+class TransformerLanguageModel(MegatronModule):
     """Transformer language model.

     Arguments:
@@ -283,10 +275,14 @@ class TransformerLanguageModelBase(MegatronModule):
                  num_tokentypes=0,
                  add_decoder=False,
                  decoder_attn_mask_type=AttnMaskType.causal,
-                 add_pooler=False):
-        super(TransformerLanguageModelBase, self).__init__()
+                 add_pooler=False,
+                 pre_process=True,
+                 post_process=True):
+        super(TransformerLanguageModel, self).__init__()
         args = get_args()

+        self.pre_process = pre_process
+        self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
@@ -296,7 +292,7 @@ class TransformerLanguageModelBase(MegatronModule):
         self.add_pooler = add_pooler

         # Embeddings.
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             self.embedding = Embedding(self.hidden_size,
                                        args.padded_vocab_size,
                                        args.max_position_embeddings,
@@ -309,7 +305,10 @@ class TransformerLanguageModelBase(MegatronModule):
         self.encoder = ParallelTransformer(
             self.init_method,
             output_layer_init_method,
-            self_attn_mask_type=self.encoder_attn_mask_type)
+            self_attn_mask_type=self.encoder_attn_mask_type,
+            pre_process=self.pre_process,
+            post_process=self.post_process
+        )
         self._encoder_key = 'encoder'

         # Decoder
@@ -323,26 +322,29 @@ class TransformerLanguageModelBase(MegatronModule):
                 self_attn_mask_type=self.decoder_attn_mask_type)
             self._decoder_key = 'decoder'

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # Pooler.
             if self.add_pooler:
                 self.pooler = Pooler(self.hidden_size, self.init_method)
                 self._pooler_key = 'pooler'

-    def forward(self, enc_language_model_input, enc_attn_mask,
-                dec_language_model_input=None, dec_attn_mask=None,
+    def set_input_tensor(self, input_tensor):
+        """ See megatron.model.transformer.set_input_tensor()"""
+        self.encoder.set_input_tensor(input_tensor)
+
+    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                 get_key_value=False, pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

         # Embeddings.
-        if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = enc_language_model_input
-            embedding_output = self.embedding(input_ids, position_ids,
+        if self.pre_process:
+            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
                                               tokentype_ids=tokentype_ids)
             encoder_input = embedding_output
         else:
-            encoder_input = enc_language_model_input
+            encoder_input = None

         # encoder.
         if enc_hidden_states is None:
@@ -353,7 +355,7 @@ class TransformerLanguageModelBase(MegatronModule):
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self.add_pooler:
                 pooled_output = self.pooler(encoder_output,
                                             pooling_sequence_index)
@@ -362,13 +364,12 @@ class TransformerLanguageModelBase(MegatronModule):
         # output. For example, it is helpful to compute
         # similarity between two sequences by average pooling
         if not self.add_decoder or output_enc_hidden:
-            if self.add_pooler and mpu.is_pipeline_last_stage():
+            if self.add_pooler and self.post_process:
                 return encoder_output, pooled_output
             else:
                 return encoder_output

         # Decoder Embedding
-        (dec_input_ids, dec_position_ids) = dec_language_model_input
         dec_embedding_output = self.embedding(dec_input_ids,
                                               dec_position_ids)
         # decoder
@@ -379,7 +380,7 @@ class TransformerLanguageModelBase(MegatronModule):
                                       encoder_output=encoder_output,
                                       enc_dec_attn_mask=enc_dec_attn_mask)

-        if self.add_pooler and mpu.is_pipeline_last_stage():
+        if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
         else:
             return decoder_output, encoder_output
@@ -389,14 +390,14 @@ class TransformerLanguageModelBase(MegatronModule):
         """For easy load."""

         state_dict_ = {}
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             state_dict_[self._embedding_key] \
                 = self.embedding.state_dict_for_save_checkpoint(
                     destination, prefix, keep_vars)
         state_dict_[self._encoder_key] \
             = self.encoder.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self.add_pooler:
                 state_dict_[self._pooler_key] \
                     = self.pooler.state_dict_for_save_checkpoint(
@@ -412,7 +413,7 @@ class TransformerLanguageModelBase(MegatronModule):
         """Customized load."""

         # Embedding.
-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             if self._embedding_key in state_dict:
                 state_dict_ = state_dict[self._embedding_key]
             else:
@@ -448,7 +449,7 @@ class TransformerLanguageModelBase(MegatronModule):
         self.encoder.load_state_dict(state_dict_, strict=strict)

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # pooler
             if self.add_pooler:
                 assert 'pooler' in state_dict, \
@@ -461,124 +462,3 @@ class TransformerLanguageModelBase(MegatronModule):
                     'could not find data for pooler in the checkpoint'
             self.decoder.load_state_dict(state_dict[self._decoder_key],
                                          strict=strict)
-
-
-class TransformerLanguageModel(TransformerLanguageModelBase):
-    """Transformer language model (see TransformerLanguageModelBase
-    for description of arguments).
-    """
-
-    def __init__(self, init_method, output_layer_init_method,
-                 encoder_attn_mask_type,
-                 num_tokentypes=0,
-                 decoder_attn_mask_type=AttnMaskType.causal,
-                 add_decoder=False,
-                 add_pooler=False):
-        super(TransformerLanguageModel, self).__init__(
-            init_method,
-            output_layer_init_method,
-            encoder_attn_mask_type,
-            num_tokentypes=num_tokentypes,
-            add_decoder=add_decoder,
-            decoder_attn_mask_type=decoder_attn_mask_type,
-            add_pooler=add_pooler)
-
-    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
-                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0,
-                enc_hidden_states=None, output_enc_hidden=False):
-        return super(TransformerLanguageModel, self).forward(
-            (enc_input_ids, enc_position_ids),
-            enc_attn_mask,
-            dec_language_model_input=(dec_input_ids, dec_position_ids),
-            dec_attn_mask=dec_attn_mask,
-            enc_dec_attn_mask=enc_dec_attn_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index,
-            enc_hidden_states=enc_hidden_states,
-            output_enc_hidden=output_enc_hidden)
-
-
-class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
-    """Transformer language model, first stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-
-    def __init__(self, init_method, output_layer_init_method,
-                 encoder_attn_mask_type,
-                 num_tokentypes=0):
-        super(TransformerLanguageModelFirstStage, self).__init__(
-            init_method,
-            output_layer_init_method,
-            encoder_attn_mask_type,
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False):
-        return super(TransformerLanguageModelFirstStage, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
-    """Transformer language model, intermediate stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-
-    def __init__(self, init_method, output_layer_init_method,
-                 encoder_attn_mask_type):
-        super(TransformerLanguageModelIntermediateStage, self).__init__(
-            init_method,
-            output_layer_init_method,
-            encoder_attn_mask_type)
-
-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False):
-        return super(TransformerLanguageModelIntermediateStage, self).forward(
-            hidden_states,
-            attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
-
-
-class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
-    """Transformer language model, final stage (see
-    TransformerLanguageModelBase for description of arguments).
-    """
-
-    def __init__(self, init_method, output_layer_init_method,
-                 encoder_attn_mask_type,
-                 add_pooler=False):
-        super(TransformerLanguageModelLastStage, self).__init__(
-            init_method,
-            output_layer_init_method,
-            encoder_attn_mask_type,
-            add_pooler=add_pooler)
-
-    def forward(self, hidden_states, attention_mask,
-                layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
-        return super(TransformerLanguageModelLastStage, self).forward(
-            hidden_states,
-            attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index)
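
With get_language_model now taking pre_process/post_process directly, the model classes above simply forward their own flags (compare the BERT, GPT, classification and multiple-choice hunks on this page). A hedged usage sketch follows; `args` stands in for the Megatron argument namespace and the helper import paths are assumed to match this commit's layout:

from megatron.model import get_language_model
from megatron.model.utils import init_method_normal, scaled_init_method_normal

def build_lm(args, attn_mask_type, add_pooler=False,
             pre_process=True, post_process=True):
    # Returns (language_model, checkpoint_key), exactly as the callers above use it.
    return get_language_model(
        num_tokentypes=0,
        add_pooler=add_pooler,
        encoder_attn_mask_type=attn_mask_type,
        init_method=init_method_normal(args.init_method_std),
        scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                     args.num_layers),
        pre_process=pre_process,
        post_process=post_process)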
megatron/model/multiple_choice.py

@@ -28,13 +28,18 @@ from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule


-class MultipleChoiceBase(MegatronModule):
+class MultipleChoice(MegatronModule):

-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceBase, self).__init__(share_word_embeddings=False)
+    def __init__(self,
+                 num_tokentypes=2,
+                 pre_process=True,
+                 post_process=True):
+        super(MultipleChoice, self).__init__(share_word_embeddings=False)
         args = get_args()

         init_method = init_method_normal(args.init_method_std)
+        self.pre_process = pre_process
+        self.post_process = post_process

         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
@@ -42,15 +47,21 @@ class MultipleChoiceBase(MegatronModule):
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers))
+                                                         args.num_layers),
+            pre_process=self.pre_process,
+            post_process=self.post_process)

         # Multi-choice head.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
             self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                      init_method)
             self._multichoice_head_key = 'multichoice_head'

+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.language_model.set_input_tensor(input_tensor)
+
     def forward(self, model_input, attention_mask, tokentype_ids=None):

         # [batch, choices, sequence] --> [batch * choices, sequence] -->
@@ -64,22 +75,21 @@ class MultipleChoiceBase(MegatronModule):
         attention_mask = attention_mask.view(-1, attention_mask.size(-1))
         extended_attention_mask = bert_extended_attention_mask(attention_mask)

-        kwargs = {}
-        if mpu.is_pipeline_first_stage():
-            input_ids = model_input
-            # Do the same as attention_mask for input_ids, tokentype_ids
-            assert len(input_ids.shape) == 3
-            assert len(tokentype_ids.shape) == 3
-            input_ids = input_ids.view(-1, input_ids.size(-1))
-            tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
-            position_ids = bert_position_ids(input_ids)
-            args = [input_ids, position_ids, extended_attention_mask]
-            kwargs['tokentype_ids'] = tokentype_ids
-        else:
-            args = [model_input, extended_attention_mask]
-        lm_output = self.language_model(*args, **kwargs)
-        if mpu.is_pipeline_last_stage():
+        input_ids = model_input
+        # Do the same as attention_mask for input_ids, tokentype_ids
+        assert len(input_ids.shape) == 3
+        assert len(tokentype_ids.shape) == 3
+        input_ids = input_ids.view(-1, input_ids.size(-1))
+        tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
+        position_ids = bert_position_ids(input_ids)
+
+        lm_output = self.language_model(
+            input_ids,
+            position_ids,
+            extended_attention_mask,
+            tokentype_ids=tokentype_ids
+        )
+        if self.post_process:
             _, pooled_output = lm_output
             multichoice_output = self.multichoice_dropout(pooled_output)
             multichoice_logits = self.multichoice_head(multichoice_output)
@@ -99,7 +109,7 @@ class MultipleChoiceBase(MegatronModule):
         state_dict_[self._language_model_key] \
             = self.language_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             state_dict_[self._multichoice_head_key] \
                 = self.multichoice_head.state_dict(
                     destination, prefix, keep_vars)
@@ -110,7 +120,7 @@ class MultipleChoiceBase(MegatronModule):
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             if self._multichoice_head_key in state_dict:
                 self.multichoice_head.load_state_dict(
                     state_dict[self._multichoice_head_key], strict=strict)
@@ -118,54 +128,3 @@ class MultipleChoiceBase(MegatronModule):
                 print_rank_last('***WARNING*** could not find {} in the checkpoint, '
                                 'initializing to random'.format(
                                     self._multichoice_head_key))
-
-
-class MultipleChoice(MultipleChoiceBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoice, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(MultipleChoice, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class MultipleChoiceFirstStage(MultipleChoiceBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask, tokentype_ids=None):
-        return super(MultipleChoiceFirstStage, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class MultipleChoiceIntermediateStage(MultipleChoiceBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(MultipleChoiceIntermediateStage, self).forward(
-            hidden_state,
-            attention_mask)
-
-
-class MultipleChoiceLastStage(MultipleChoiceBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(MultipleChoiceLastStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(MultipleChoiceLastStage, self).forward(
-            hidden_state,
-            attention_mask)
megatron/model/transformer.py

@@ -532,12 +532,16 @@ class ParallelTransformer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 self_attn_mask_type=AttnMaskType.padding,
+                 pre_process=True, post_process=True):
         super(ParallelTransformer, self).__init__()
         args = get_args()

         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.input_tensor = None

         # Store activation checkpoiting flag.
         self.checkpoint_activations = args.checkpoint_activations
@@ -572,15 +576,16 @@ class ParallelTransformer(MegatronModule):
             # Stage 0: [0, 1] [4, 5]
             # Stage 1: [2, 3] [6, 7]
             offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                 args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                 (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
             # Each stage gets a contiguous set of layers.
             offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])

-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
@@ -615,6 +620,16 @@ class ParallelTransformer(MegatronModule):
         return hidden_states

+    def set_input_tensor(self, input_tensor):
+        """Set input tensor to be used instead of forward()'s input.
+
+        When doing pipeline parallelism the input from the previous
+        stage comes from communication, not from the input, so the
+        model's forward_step_func won't have it. This function is thus
+        used by internal code to bypass the input provided by the
+        forward_step_func"""
+        self.input_tensor = input_tensor
+
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
@@ -628,7 +643,7 @@ class ParallelTransformer(MegatronModule):
                 'get_key_value does not work with ' \
                 'activation checkpointing'

-        if mpu.is_pipeline_first_stage():
+        if self.pre_process:
             # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
             # If the input flag for fp32 residual connection is set, convert for float.
             if self.fp32_residual_connection:
@@ -636,10 +651,13 @@ class ParallelTransformer(MegatronModule):
             # Otherwise, leave it as is.
             else:
                 hidden_states = hidden_states.transpose(0, 1).contiguous()
+        else:
+            # See set_input_tensor()
+            hidden_states = self.input_tensor

         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()

         if self.checkpoint_activations:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask,
@@ -664,7 +682,7 @@ class ParallelTransformer(MegatronModule):
                     presents.append(present)

         # Final layer norm.
-        if mpu.is_pipeline_last_stage():
+        if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
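
The pre_process/post_process flags plus set_input_tensor() are the mechanism that replaces the per-stage classes: a stage that is not first ignores the hidden_states passed to forward() and instead consumes the activation the pipeline schedule injected. The following toy, self-contained PyTorch sketch is an illustration written for this summary (not code from the diff) showing that data flow on a single process:

import torch
import torch.nn as nn

class ToyStage(nn.Module):
    def __init__(self, hidden, pre_process, post_process):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.embed = nn.Embedding(100, hidden) if pre_process else None
        self.block = nn.Linear(hidden, hidden)
        self.head = nn.Linear(hidden, 100) if post_process else None

    def set_input_tensor(self, input_tensor):
        # Called by the pipeline schedule with the activation received
        # from the previous stage.
        self.input_tensor = input_tensor

    def forward(self, tokens):
        if self.pre_process:
            hidden = self.embed(tokens)
        else:
            hidden = self.input_tensor   # see set_input_tensor()
        hidden = self.block(hidden)
        return self.head(hidden) if self.post_process else hidden

# Two-stage "pipeline" on one process, just to show the hand-off.
tokens = torch.randint(0, 100, (2, 8))
stage0 = ToyStage(16, pre_process=True, post_process=False)
stage1 = ToyStage(16, pre_process=False, post_process=True)
activation = stage0(tokens)            # stage 0 embeds and runs its block
stage1.set_input_tensor(activation)    # normally delivered via p2p communication
logits = stage1(tokens)                # tokens are ignored on a non-first stage
print(logits.shape)                    # torch.Size([2, 8, 100])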
megatron/schedules.py

@@ -22,6 +22,20 @@ from megatron import get_num_microbatches
 from megatron import get_timers
 from megatron import mpu
 from megatron import p2p_communication
+from megatron.utils import unwrap_model
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.model import Float16Module
+
+
+def get_forward_backward_func():
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if args.virtual_pipeline_model_parallel_size is not None:
+            forward_backward_func = forward_backward_pipelining_with_interleaving
+        else:
+            forward_backward_func = forward_backward_pipelining_without_interleaving
+    else:
+        forward_backward_func = forward_backward_no_pipelining
+    return forward_backward_func


 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
@@ -34,8 +48,12 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     timers = get_timers()

     timers('forward-compute').start()
-    output_tensor = forward_step_func(data_iterator, model, input_tensor)
+    unwrapped_model = unwrap_model(
+        model, (torchDDP, LocalDDP, Float16Module))
+    unwrapped_model.set_input_tensor(input_tensor)
+    output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
+        output_tensor = loss_func(output_tensor)
         loss, loss_reduced = output_tensor
         output_tensor = loss / get_num_microbatches()
         losses_reduced.append(loss_reduced)
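
forward_step() now unwraps the model (through the torchDDP/LocalDDP/Float16Module wrappers), injects the received activation via set_input_tensor(), and expects forward_step_func to return the output tensor together with a loss function that the scheduler applies only on the last pipeline stage. A hedged sketch of what that contract asks of a training script — get_batch here is a simplified stand-in, not the real Megatron helper:

import torch

def get_batch(data_iterator):
    # Stand-in batching: next-token labels from the token stream.
    tokens = next(data_iterator)
    labels = torch.roll(tokens, shifts=-1, dims=1)
    return tokens, labels

def forward_step(data_iterator, model):
    tokens, labels = get_batch(data_iterator)
    # On non-last stages this is just an activation; the injected input_tensor
    # has already been set on the model by the scheduler.
    output_tensor = model(tokens)

    def loss_func(output_tensor):
        # Only invoked by the scheduler on the last pipeline stage.
        loss = torch.nn.functional.cross_entropy(
            output_tensor.flatten(0, 1), labels.flatten())
        return loss, {'lm loss': loss.detach()}

    return output_tensor, loss_func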
megatron/text_generation_utils.py
View file @ 3fc035d7
...
@@ -26,9 +26,13 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.training import communicate
-from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
+from megatron.p2p_communication import recv_forward, send_forward
+
+# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.model import Float16Module


 def get_batch(context_tokens):
     """Generate batch from context tokens."""
...
@@ -395,55 +399,28 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                  layer_past=None, get_key_value=None,
                  forward_method_parallel_output=None):

-    # Hidden size changes when not using recompute, need to tell communicate()
-    # the correct size
+    # Hidden size changes when not using recompute, need to tell p2p_communicate
+    # functions the correct size
     args = get_args()
     orig_seq_length = args.seq_length
     args.seq_length = tokens.shape[1]

-    if not mpu.is_pipeline_first_stage():
-        input_tensor, _ = communicate(
-            tensor_send_next=None,
-            tensor_send_prev=None,
-            recv_forward=True,
-            recv_backward=False)
-    else:
-        input_tensor = None
+    input_tensor = recv_forward()

     # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output_tensor = model(tokens, position_ids, attention_mask,
-                                  tokentype_ids=tokentype_ids,
-                                  layer_past=layer_past,
-                                  get_key_value=get_key_value,
-                                  forward_method_parallel_output=forward_method_parallel_output)
-        else:
-            output_tensor = model(tokens, position_ids, attention_mask,
-                                  tokentype_ids=tokentype_ids,
-                                  layer_past=layer_past,
-                                  get_key_value=get_key_value)
-    elif mpu.is_pipeline_last_stage():
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask,
-                              layer_past=layer_past,
-                              get_key_value=get_key_value,
-                              forward_method_parallel_output=forward_method_parallel_output)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask,
-                              layer_past=layer_past,
-                              get_key_value=get_key_value)
+    unwrapped_model = unwrap_model(
+        model, (torchDDP, LocalDDP, Float16Module))
+    unwrapped_model.set_input_tensor(input_tensor)
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          tokentype_ids=tokentype_ids,
+                          layer_past=layer_past,
+                          get_key_value=get_key_value,
+                          forward_method_parallel_output=forward_method_parallel_output)

     if get_key_value:
         output_tensor, layer_past = output_tensor

-    if not mpu.is_pipeline_last_stage():
-        communicate(tensor_send_next=output_tensor,
-                    tensor_send_prev=None,
-                    recv_forward=False,
-                    recv_backward=False)
+    send_forward(output_tensor)

     args.seq_length = orig_seq_length
     if get_key_value:
...
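A rough sketch of the stage-boundary flow this file now follows (illustrative only: recv_forward and send_forward are stubbed here rather than Megatron's p2p_communication, and the "stage" is a toy module running in a single process).

import torch

def recv_forward():
    # Stub: the first stage receives nothing; later stages would receive
    # the previous stage's activations over torch.distributed here.
    return None

def send_forward(tensor):
    # Stub: the last stage sends nothing; earlier stages would ship their
    # activations to the next stage here.
    pass

class Stage(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 8)
        self.input_tensor = None

    def set_input_tensor(self, tensor):
        # Mirrors the refactor: the caller hands received activations to the
        # model instead of the forward step branching on the stage rank.
        self.input_tensor = tensor

    def forward(self, tokens):
        hidden = tokens if self.input_tensor is None else self.input_tensor
        return self.layer(hidden)

stage = Stage()
stage.set_input_tensor(recv_forward())
output = stage(torch.randn(2, 8))
send_forward(output)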
megatron/training.py
View file @ 3fc035d7
...
@@ -61,10 +61,10 @@ def print_datetime(string):
     print_rank_0('[' + string + '] datetime: {} '.format(time_str))


 def pretrain(train_valid_test_dataset_provider,
              model_provider,
              forward_step_func,
              extra_args_provider=None,
              args_defaults={}):
     """Main training program.
...
@@ -196,7 +196,25 @@ def get_model(model_provider_func):
     args = get_args()

     # Build model on cpu.
-    model = model_provider_func()
+    pre_process = mpu.is_pipeline_first_stage()
+    post_process = mpu.is_pipeline_last_stage()
+
+    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
+       args.virtual_pipeline_model_parallel_size is not None:
+        model = []
+        for i in range(args.virtual_pipeline_model_parallel_size):
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            this_model = model_provider_func(
+                pre_process=pre_process,
+                post_process=post_process)
+            model.append(this_model)
+    else:
+        model = model_provider_func(
+            pre_process=pre_process,
+            post_process=post_process)
+
     if not isinstance(model, list):
         model = [model]
...
@@ -231,7 +249,7 @@ def get_model(model_provider_func):
                                process_group=mpu.get_data_parallel_group())
                  for model_module in model]
         return model

     if args.DDP_impl == 'local':
         model = [LocalDDP(model_module,
                           args.accumulate_allreduce_grads_in_fp32,
...
@@ -651,16 +669,16 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                 if not saved_checkpoint:
                     save_checkpoint_and_time(iteration, model, optimizer,
                                              lr_scheduler)
                 print_datetime('exiting program after {} minutes'.format(train_time))
                 sys.exit()

         # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
             if not saved_checkpoint:
                 save_checkpoint_and_time(iteration, model, optimizer,
                                          lr_scheduler)
             torch.distributed.barrier()
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
...
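To illustrate what the new pre_process/post_process flags buy (a sketch, not the BertModel/GPTModel code): one model class can serve every pipeline stage, and the provider only tells it whether to build the embedding front-end and/or the output head.

import torch

class ToyStageModel(torch.nn.Module):
    def __init__(self, hidden=8, vocab=16, pre_process=True, post_process=True):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        # Only the first stage owns the embedding, only the last owns the head.
        self.embedding = torch.nn.Embedding(vocab, hidden) if pre_process else None
        self.body = torch.nn.Linear(hidden, hidden)
        self.head = torch.nn.Linear(hidden, vocab) if post_process else None
        self.input_tensor = None

    def set_input_tensor(self, tensor):
        self.input_tensor = tensor

    def forward(self, tokens):
        hidden = self.embedding(tokens) if self.pre_process else self.input_tensor
        hidden = self.body(hidden)
        return self.head(hidden) if self.post_process else hidden

def model_provider(pre_process=True, post_process=True):
    return ToyStageModel(pre_process=pre_process, post_process=post_process)

# A middle stage: no embedding, no head, consumes the received activations.
middle = model_provider(pre_process=False, post_process=False)
middle.set_input_tensor(torch.randn(2, 4, 8))
out = middle(tokens=None)   # tokens are unused when pre_process is False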
pretrain_bert.py
View file @ 3fc035d7
...
@@ -17,56 +17,30 @@
 import torch
 import torch.nn.functional as F
+from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import (BertModel,
-                            BertModelFirstStage,
-                            BertModelIntermediateStage,
-                            BertModelLastStage)
+from megatron.model import BertModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group


-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0('building BERT model ...')

     args = get_args()
     num_tokentypes = 2 if args.bert_binary_head else 0
-
-    def model_provider_pipelined():
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = BertModelFirstStage(
-                num_tokentypes=num_tokentypes)
-        elif mpu.is_pipeline_last_stage():
-            model = BertModelLastStage(
-                num_tokentypes=num_tokentypes,
-                add_binary_head=args.bert_binary_head,
-                parallel_output=True)
-        else:
-            model = BertModelIntermediateStage(
-                num_tokentypes=num_tokentypes)
-        return model
-
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        if args.virtual_pipeline_model_parallel_size is not None:
-            model = []
-            for i in range(args.virtual_pipeline_model_parallel_size):
-                mpu.set_virtual_pipeline_model_parallel_rank(i)
-                model.append(model_provider_pipelined())
-        else:
-            model = model_provider_pipelined()
-    else:
-        model = BertModel(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=args.bert_binary_head,
-            parallel_output=True)
+    model = BertModel(
+        num_tokentypes=num_tokentypes,
+        add_binary_head=args.bert_binary_head,
+        parallel_output=True,
+        pre_process=pre_process,
+        post_process=post_process)

     return model
...
@@ -96,7 +70,33 @@ def get_batch(data_iterator):
     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


-def forward_step(data_iterator, model, input_tensor):
+def loss_func(loss_mask, sentence_order, output_tensor):
+    lm_loss_, sop_logits = output_tensor
+
+    lm_loss_ = lm_loss_.float()
+    loss_mask = loss_mask.float()
+    lm_loss = torch.sum(
+        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+
+    if sop_logits is not None:
+        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+                                   sentence_order.view(-1),
+                                   ignore_index=-1)
+        sop_loss = sop_loss.float()
+        loss = lm_loss + sop_loss
+        averaged_losses = average_losses_across_data_parallel_group(
+            [lm_loss, sop_loss])
+        return loss, {'lm loss': averaged_losses[0],
+                      'sop loss': averaged_losses[1]}
+    else:
+        loss = lm_loss
+        averaged_losses = average_losses_across_data_parallel_group(
+            [lm_loss])
+        return loss, {'lm loss': averaged_losses[0]}
+
+
+def forward_step(data_iterator, model):
     """Forward step."""
     args = get_args()
     timers = get_timers()
...
@@ -111,46 +111,10 @@ def forward_step(data_iterator, model, input_tensor):
         types = None

     # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output_tensor = model(tokens, padding_mask, tokentype_ids=types,
-                                  lm_labels=lm_labels)
-        else:
-            output_tensor = model(tokens, padding_mask, tokentype_ids=types)
-    elif mpu.is_pipeline_last_stage():
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, padding_mask)
-
-    if mpu.is_pipeline_last_stage():
-        lm_loss_, sop_logits = output_tensor
-
-        lm_loss_ = lm_loss_.float()
-        loss_mask = loss_mask.float()
-        lm_loss = torch.sum(
-            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
-
-        if sop_logits is not None:
-            sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
-                                       sentence_order.view(-1),
-                                       ignore_index=-1)
-            sop_loss = sop_loss.float()
-            loss = lm_loss + sop_loss
-            averaged_losses = average_losses_across_data_parallel_group(
-                [lm_loss, sop_loss])
-            return loss, {'lm loss': averaged_losses[0],
-                          'sop loss': averaged_losses[1]}
-        else:
-            loss = lm_loss
-            averaged_losses = average_losses_across_data_parallel_group(
-                [lm_loss])
-            return loss, {'lm loss': averaged_losses[0]}
-
-    return output_tensor
+    output_tensor = model(tokens, padding_mask, tokentype_ids=types,
+                          lm_labels=lm_labels)
+
+    return output_tensor, partial(loss_func, loss_mask, sentence_order)


 def train_valid_test_datasets_provider(train_val_test_num_samples):
...
pretrain_gpt.py
View file @ 3fc035d7
...
@@ -16,50 +16,28 @@
 """Pretrain GPT"""

 import torch
+from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from megatron.model import (GPTModel,
-                            GPTModelFirstStage,
-                            GPTModelIntermediateStage,
-                            GPTModelLastStage)
+from megatron.model import GPTModel
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group


-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0('building GPT model ...')
-
-    def model_provider_pipelined():
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = GPTModelFirstStage(num_tokentypes=0)
-        elif mpu.is_pipeline_last_stage():
-            model = GPTModelLastStage(
-                num_tokentypes=0, parallel_output=True)
-        else:
-            model = GPTModelIntermediateStage(num_tokentypes=0)
-        return model
-
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        if args.virtual_pipeline_model_parallel_size is not None:
-            model = []
-            for i in range(args.virtual_pipeline_model_parallel_size):
-                mpu.set_virtual_pipeline_model_parallel_rank(i)
-                model.append(model_provider_pipelined())
-        else:
-            model = model_provider_pipelined()
-    else:
-        model = GPTModel(num_tokentypes=0, parallel_output=True)
+    model = GPTModel(
+        num_tokentypes=0,
+        parallel_output=True,
+        pre_process=pre_process,
+        post_process=post_process
+    )

     return model
...
@@ -94,8 +72,18 @@ def get_batch(data_iterator):
     return tokens, labels, loss_mask, attention_mask, position_ids


-def forward_step(data_iterator, model, input_tensor):
+def loss_func(loss_mask, output_tensor):
+    losses = output_tensor.float()
+    loss_mask = loss_mask.view(-1).float()
+    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+
+    return loss, {'lm loss': averaged_loss[0]}
+
+
+def forward_step(data_iterator, model):
     """Forward step."""
     args = get_args()
     timers = get_timers()
...
@@ -106,31 +94,10 @@ def forward_step(data_iterator, model, input_tensor):
         data_iterator)
     timers('batch-generator').stop()

-    # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output_tensor = model(tokens, position_ids, attention_mask,
-                                  labels=labels)
-        else:
-            output_tensor = model(tokens, position_ids, attention_mask)
-    elif mpu.is_pipeline_last_stage():
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask, labels=labels)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask)
-
-    if mpu.is_pipeline_last_stage():
-        losses = output_tensor.float()
-        loss_mask = loss_mask.view(-1).float()
-        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
-
-        # Reduce loss for logging.
-        averaged_loss = average_losses_across_data_parallel_group([loss])
-
-        return loss, {'lm loss': averaged_loss[0]}
-
-    return output_tensor
+    output_tensor = model(tokens, position_ids, attention_mask,
+                          labels=labels)
+
+    return output_tensor, partial(loss_func, loss_mask)


 def train_valid_test_datasets_provider(train_val_test_num_samples):
...
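The loss_func moved out of forward_step above computes a masked mean over per-token losses. A tiny worked example of that arithmetic, standalone and with made-up numbers:

import torch

per_token_loss = torch.tensor([[2.0, 4.0], [6.0, 8.0]])
loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])   # last position is padding

masked_sum = torch.sum(per_token_loss.view(-1) * loss_mask.view(-1))   # 2 + 4 + 6 = 12
loss = masked_sum / loss_mask.sum()                                    # 12 / 3 = 4
print(loss.item())   # 4.0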
tasks/eval_utils.py
View file @ 3fc035d7
...
@@ -17,13 +17,14 @@
 import os
 import time
+from functools import partial

 import torch

 from megatron import get_args
 from megatron import print_rank_last, is_last_rank
 from megatron import mpu
-from megatron.training import communicate
+from megatron.schedules import get_forward_backward_func
 from tasks.finetune_utils import build_data_loader
 from tasks.finetune_utils import process_batch
...
@@ -38,7 +39,7 @@ def accuracy_func_provider(single_dataset_provider):
     for datapath in datapaths:
         dataset = single_dataset_provider(datapath)
         dataloader = build_data_loader(
-            dataset, args.micro_batch_size, num_workers=args.num_workers,
+            dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
             drop_last=(mpu.get_data_parallel_world_size() > 1))
         dataloaders.append((dataset.dataset_name, dataloader))
...
@@ -73,14 +74,66 @@ def accuracy_func_provider(single_dataset_provider):
     return metrics_func


 def calculate_correct_answers(name, model, dataloader,
                               epoch, output_predictions):
     """Calculate correct over total answers and return prediction if the
     `output_predictions` is true."""

     args = get_args()
+    forward_backward_func = get_forward_backward_func()
     start_time = time.time()
-    model.eval()
-    saved_batch_size = args.micro_batch_size
+    for m in model:
+        m.eval()
+    saved_micro_batch_size = args.micro_batch_size
+    saved_global_batch_size = args.global_batch_size
+
+    ds = dataloader.dataset
+    if hasattr(ds, 'sample_multiplier'):
+        # If our dataset as a sample_multiplier attribute that means
+        # each "sample" from the dataset actually has multiple samples
+        # that will collapse into the batch dimension (for example in
+        # the RACE dataset that has several options), we need to
+        # account for that when setting the micro batch size.
+        sample_multiplier = ds.sample_multiplier
+    else:
+        sample_multiplier = 1
+    micro_batch_size_times_data_parallel = \
+        args.orig_micro_batch_size * args.data_parallel_size
+    num_micro_batches = \
+        args.orig_global_batch_size // micro_batch_size_times_data_parallel
+
+    def loss_func(output_predictions, labels, output_tensor):
+        logits = output_tensor
+
+        loss_dict = {}
+        # Add output predictions.
+        if output_predictions:
+            assert False
+            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
+                logits.float()).data.cpu().numpy().tolist()
+            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
+            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
+        # Compute the correct answers.
+        predicted = torch.argmax(logits, dim=-1)
+        corrects = (predicted == labels)
+        # Add to the counters.
+        loss_dict['total'] = labels.size(0)
+        loss_dict['correct'] = corrects.sum().item()
+
+        return 0, loss_dict
+
+    # defined inside to capture output_predictions
+    def correct_answers_forward_step(batch, model):
+        try:
+            batch_ = next(batch)
+        except BaseException:
+            batch_ = batch
+        tokens, types, labels, attention_mask = process_batch(batch_)
+
+        # Forward model.
+        args = get_args()
+        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
+
+        return output_tensor, partial(loss_func, output_predictions, labels)

     with torch.no_grad():
         # For all the batches in the dataset.
         total = 0
...
@@ -92,60 +145,30 @@ def calculate_correct_answers(name, model, dataloader,
         labels = []
         ids = []
         for _, batch in enumerate(dataloader):
-            # Run the model forward.
-            tokens, types, labels_, attention_mask = process_batch(batch)
-
             # For evaluation only mode we use drop_last = False to get all the
             # samples, which means we might not have a full batch, so we
             # adjust batch_size here to actual batch size of data
-            actual_batch_size = len(labels_)
+            actual_batch_size = len(batch['label'])
             # ... applying sample_multiplier if necessary
-            ds = dataloader.dataset
-            if hasattr(ds, 'sample_multiplier'):
-                actual_batch_size *= ds.sample_multiplier
-            args.micro_batch_size = actual_batch_size
-
-            if not mpu.is_pipeline_first_stage():
-                input_tensor, _ = communicate(
-                    tensor_send_next=None,
-                    tensor_send_prev=None,
-                    recv_forward=True,
-                    recv_backward=False)
-            else:
-                input_tensor = None
+            args.micro_batch_size = actual_batch_size * sample_multiplier
+            args.global_batch_size = actual_batch_size * sample_multiplier \
+                * num_micro_batches

             # Forward model.
-            if mpu.is_pipeline_first_stage():
-                assert input_tensor is None
-                output_tensor = model(tokens, attention_mask, tokentype_ids=types)
-            else:
-                assert input_tensor is not None
-                output_tensor = model(input_tensor, attention_mask)
-
-            if mpu.is_pipeline_last_stage():
-                logits = output_tensor
-
-                # Add output predictions.
-                if output_predictions:
-                    softmaxes.extend(torch.nn.Softmax(dim=-1)(
-                        logits.float()).data.cpu().numpy().tolist())
-                    labels.extend(labels_.data.cpu().numpy().tolist())
-                    ids.extend(batch['uid'].cpu().numpy().tolist())
-                # Compute the correct answers.
-                predicted = torch.argmax(logits, dim=-1)
-                corrects = (predicted == labels_)
-                # Add to the counters.
-                total += labels_.size(0)
-                correct += corrects.sum().item()
-            else:
-                communicate(tensor_send_next=output_tensor,
-                            tensor_send_prev=None,
-                            recv_forward=False,
-                            recv_backward=False)
-
-    model.train()
-    args.micro_batch_size = saved_batch_size
+            loss_dicts = forward_backward_func(correct_answers_forward_step,
+                                               batch, model,
+                                               optimizer=None, timers=None,
+                                               forward_only=True)
+
+            for loss_dict in loss_dicts:
+                if output_predictions:
+                    softmaxes.extend(loss_dict['softmaxes'])
+                    labels.extend(loss_dict['labels'])
+                    ids.extend(loss_dict['ids'])
+                total += loss_dict['total']
+                correct += loss_dict['correct']
+
+    for m in model:
+        m.train()
+    args.micro_batch_size = saved_micro_batch_size
+    args.global_batch_size = saved_global_batch_size

     # Reduce.
     if mpu.is_pipeline_last_stage():
...
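The refactored evaluation loop accumulates one small dict per micro-batch instead of touching logits directly. A stripped-down illustration of that accumulation with toy data and no pipeline or Megatron calls:

import torch

def classification_loss_dict(logits, labels):
    # Same bookkeeping as the per-micro-batch loss_dict above.
    predicted = torch.argmax(logits, dim=-1)
    return {'total': labels.size(0), 'correct': (predicted == labels).sum().item()}

loss_dicts = [
    classification_loss_dict(torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 1])),
    classification_loss_dict(torch.tensor([[0.3, 0.7]]), torch.tensor([0])),
]

total = sum(d['total'] for d in loss_dicts)      # 3
correct = sum(d['correct'] for d in loss_dicts)  # 2
print(correct / total)                           # ~0.667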
tasks/finetune_utils.py
View file @ 3fc035d7
...
@@ -15,6 +15,8 @@
 """Finetune utilities."""

+from functools import partial
+
 import torch

 from megatron import get_args
...
@@ -46,7 +48,20 @@ def process_batch(batch):
     return tokens, types, labels, attention_mask


-def _cross_entropy_forward_step(batch, model, input_tensor):
+def cross_entropy_loss_func(labels, output_tensor):
+    logits = output_tensor
+
+    # Cross-entropy loss.
+    loss_func = torch.nn.CrossEntropyLoss()
+    loss = loss_func(logits.contiguous().float(), labels)
+
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+
+    return loss, {'lm loss': averaged_loss[0]}
+
+
+def _cross_entropy_forward_step(batch, model):
     """Simple forward step with cross-entropy loss."""
     timers = get_timers()
...
@@ -60,25 +75,9 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
     timers('batch-generator').stop()

     # Forward model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
-    else:
-        assert input_tensor is not None
-        output_tensor = model(input_tensor, attention_mask)
-
-    if mpu.is_pipeline_last_stage():
-        logits = output_tensor
-
-        # Cross-entropy loss.
-        loss_func = torch.nn.CrossEntropyLoss()
-        loss = loss_func(logits.contiguous().float(), labels)
-
-        # Reduce loss for logging.
-        averaged_loss = average_losses_across_data_parallel_group([loss])
-
-        return loss, {'lm loss': averaged_loss[0]}
-
-    return output_tensor
+    output_tensor = model(tokens, attention_mask, tokentype_ids=types)
+
+    return output_tensor, partial(cross_entropy_loss_func, labels)


 def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
...
@@ -135,7 +134,14 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     # This is necessary so pipeline transfers know what size they are
     # and the LR schedule, which is based on samples seen, gets set
     # correctly.
+    args.orig_micro_batch_size = args.micro_batch_size
+    args.orig_global_batch_size = args.global_batch_size
     if hasattr(train_dataset, 'sample_multiplier'):
+        # If our dataset as a sample_multiplier attribute that means
+        # each "sample" from the dataset actually has multiple samples
+        # that will collapse into the batch dimension (for example in
+        # the RACE dataset that has several options), we need to
+        # account for that when setting the micro batch size.
         args.micro_batch_size *= train_dataset.sample_multiplier
         args.global_batch_size *= train_dataset.sample_multiplier
...
@@ -149,7 +155,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     timers = get_timers()

     # Turn on training mode which enables dropout.
-    model.train()
+    for m in model:
+        m.train()

     # Tracking loss.
     losses_dict_sum = {}
...
@@ -180,10 +187,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             start_iteration = 0

             # Train for one step.
-            losses_dict, skipped_iter, grad_norm = train_step(forward_step,
-                batch, model, optimizer, lr_scheduler)
+            out = train_step(forward_step, batch, model,
+                             optimizer, lr_scheduler)
+            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out

             iteration += 1

             # Logging.
...
@@ -195,7 +200,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                              iteration,
                              optimizer.get_loss_scale().item(),
                              report_memory_flag, skipped_iter,
-                             grad_norm, params_norm)
+                             grad_norm, params_norm, num_zeros_in_grad)

             # Autoresume
             if args.adlr_autoresume and \
...
@@ -231,6 +236,9 @@ def finetune(train_valid_datasets_provider, model_provider,
     args = get_args()
     timers = get_timers()

+    assert args.rampup_batch_size is None, \
+        'batch size scaling is not supported for finetuning'
+
     # Train and validation data loaders.
     timers('train/valid/test dataset/dataloder').start()
     if args.epochs > 0:
...
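The sample_multiplier bookkeeping above is just batch-size arithmetic. A small numeric sketch with made-up batch sizes, assuming four answer choices per question as in RACE:

# One "sample" from such a dataset expands into 4 rows in the batch,
# so both batch sizes are scaled up before training and restored afterwards.
sample_multiplier = 4
orig_micro_batch_size = 8
orig_global_batch_size = 64

micro_batch_size = orig_micro_batch_size * sample_multiplier    # 32 rows per micro-batch
global_batch_size = orig_global_batch_size * sample_multiplier  # 256 rows per step
print(micro_batch_size, global_batch_size)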
tasks/glue/finetune.py
View file @ 3fc035d7
...
@@ -19,7 +19,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.model.classification import Classification, ClassificationFirstStage, ClassificationIntermediateStage, ClassificationLastStage
+from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
...
@@ -39,25 +39,14 @@ def glue_classification(num_classes, Dataset,
         return train_dataset, valid_dataset

-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""
         args = get_args()

         print_rank_0('building classification model for {} ...'.format(
             args.task))
-
-        if mpu.get_pipeline_model_parallel_world_size() > 1:
-            # Determine model based on position of stage in pipeline.
-            if mpu.is_pipeline_first_stage():
-                model = ClassificationFirstStage(
-                    num_classes=num_classes, num_tokentypes=2)
-            elif mpu.is_pipeline_last_stage():
-                model = ClassificationLastStage(
-                    num_classes=num_classes, num_tokentypes=2)
-            else:
-                model = ClassificationIntermediateStage(
-                    num_classes=num_classes, num_tokentypes=2)
-        else:
-            model = Classification(num_classes=num_classes, num_tokentypes=2)
+        model = Classification(num_classes=num_classes, num_tokentypes=2,
+                               pre_process=pre_process,
+                               post_process=post_process)

         return model
...
tasks/main.py
View file @ 3fc035d7
...
@@ -70,6 +70,11 @@ if __name__ == '__main__':
     initialize_megatron(extra_args_provider=get_tasks_args)

     args = get_args()
+
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
+        exit()
+
     if args.task == 'RACE':
         from race.finetune import main
     elif args.task in ['MNLI', 'QQP']:
...
tasks/race/data.py
View file @ 3fc035d7
...
@@ -39,6 +39,8 @@ class RaceDataset(Dataset):
         print_rank_0(' >> total number of samples: {}'.format(
             len(self.samples)))

+        # This indicates that each "sample" has multiple samples that
+        # will collapse into batch dimension
         self.sample_multiplier = NUM_CHOICES

     def __len__(self):
...
tasks/race/finetune.py
View file @ 3fc035d7
...
@@ -19,7 +19,7 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.model.multiple_choice import MultipleChoice, MultipleChoiceFirstStage, MultipleChoiceIntermediateStage, MultipleChoiceLastStage
+from megatron.model.multiple_choice import MultipleChoice
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
 from tasks.race.data import RaceDataset
...
@@ -38,20 +38,13 @@ def train_valid_datasets_provider():
     return train_dataset, valid_dataset


-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0('building multichoice model for RACE ...')
-
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = MultipleChoiceFirstStage(num_tokentypes=2)
-        elif mpu.is_pipeline_last_stage():
-            model = MultipleChoiceLastStage(num_tokentypes=2)
-        else:
-            model = MultipleChoiceIntermediateStage(num_tokentypes=2)
-    else:
-        model = MultipleChoice(num_tokentypes=2)
+    model = MultipleChoice(num_tokentypes=2,
+                           pre_process=pre_process,
+                           post_process=post_process)

     return model
...
tasks/zeroshot_gpt/evaluate.py
View file @ 3fc035d7
...
@@ -24,19 +24,24 @@ from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
-from megatron.model import GPTModel, GPTModelFirstStage, GPTModelLastStage, GPTModelIntermediateStage
-from megatron.training import get_model, communicate
-from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.model import GPTModel
+from megatron.training import get_model
+from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
+from megatron.p2p_communication import recv_forward, send_forward
 from tasks.finetune_utils import build_data_loader
 from .datasets import build_dataset
+
+# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.model import Float16Module


 def get_model_provider(eval_metric):
     """Based on evaluation metric set the parallel-output flag and
     return the model provider."""

-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""

         if eval_metric == 'loss':
...
@@ -48,17 +53,8 @@ def get_model_provider(eval_metric):
                             'is not supported.'.format(eval_metric))

         print_rank_0('building GPT model ...')
-
-        if mpu.get_pipeline_model_parallel_world_size() > 1:
-            # Determine model based on position of stage in pipeline.
-            if mpu.is_pipeline_first_stage():
-                model = GPTModelFirstStage(num_tokentypes=0)
-            elif mpu.is_pipeline_last_stage():
-                model = GPTModelLastStage(
-                    parallel_output=parallel_output, num_tokentypes=0)
-            else:
-                model = GPTModelIntermediateStage(num_tokentypes=0)
-        else:
-            model = GPTModel(num_tokentypes=0, parallel_output=parallel_output)
+        model = GPTModel(num_tokentypes=0, parallel_output=parallel_output,
+                         pre_process=pre_process, post_process=post_process)

         return model
...
@@ -97,33 +93,15 @@ def forward_step(batch, model, eval_metric):
     args = get_args()
     args.micro_batch_size = len(labels)

-    # Forward model.
-    if not mpu.is_pipeline_first_stage():
-        input_tensor, _ = communicate(
-            tensor_send_next=None,
-            tensor_send_prev=None,
-            recv_forward=True,
-            recv_backward=False)
-    else:
-        input_tensor = None
+    input_tensor = recv_forward()

     # Forward pass through the model.
-    if mpu.is_pipeline_first_stage():
-        assert input_tensor is None
-        if mpu.is_pipeline_last_stage():
-            output = model(tokens, position_ids, attention_mask)
-        else:
-            output = model(tokens, position_ids, attention_mask)
-    else:
-        assert input_tensor is not None
-        output = model(input_tensor, attention_mask)
-
-    if not mpu.is_pipeline_last_stage():
-        communicate(tensor_send_next=output,
-                    tensor_send_prev=None,
-                    recv_forward=False,
-                    recv_backward=False)
-        return None
+    unwrapped_model = unwrap_model(
+        model, (torchDDP, LocalDDP, Float16Module))
+    unwrapped_model.set_input_tensor(input_tensor)
+    output = model(tokens, position_ids, attention_mask)
+
+    send_forward(output)

     if mpu.is_pipeline_last_stage():
         # For loss, return the unreduced loss.
...
@@ -214,6 +192,10 @@ def main():
     """Main program."""
     args = get_args()

+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for text generation.")
+        exit()
+
     if args.task == 'LAMBADA':
         eval_metric = 'accuracy'
     elif args.task == 'WIKITEXT103':
...
@@ -227,6 +209,9 @@ def main():
     if args.load is not None:
         _ = load_checkpoint(model, None, None)

+    assert len(model) == 1, "Above condition should have caught this"
+    model = model[0]
+
     # Data stuff.
     dataset = build_dataset(args.task)
     dataloader = build_data_loader(dataset, args.micro_batch_size,
...
tools/generate_samples_gpt.py
View file @ 3fc035d7
...
@@ -26,33 +26,19 @@ from megatron import get_tokenizer
 from megatron import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from megatron.model import (GPTModel,
-                            GPTModelFirstStage,
-                            GPTModelLastStage,
-                            GPTModelIntermediateStage)
+from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.text_generation_utils import generate_and_write_samples_unconditional
 from megatron.text_generation_utils import generate_samples_input_from_file
 from megatron.text_generation_utils import generate_samples_interactive


-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0('building GPT model ...')
-
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        # Determine model based on position of stage in pipeline.
-        if mpu.is_pipeline_first_stage():
-            model = GPTModelFirstStage(num_tokentypes=0)
-        elif mpu.is_pipeline_last_stage():
-            model = GPTModelLastStage(num_tokentypes=0, parallel_output=False)
-        else:
-            model = GPTModelIntermediateStage(num_tokentypes=0)
-    else:
-        model = GPTModel(num_tokentypes=0, parallel_output=False)
+    model = GPTModel(num_tokentypes=0, parallel_output=False,
+                     pre_process=pre_process, post_process=post_process)

     return model
...
@@ -96,12 +82,20 @@ def main():
                                        'no_load_rng': True,
                                        'no_load_optim': True})

+    args = get_args()
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        print("Interleaved pipeline schedule is not yet supported for text generation.")
+        exit()
+
     # Set up model and load checkpoint.
     model = get_model(model_provider)
-    args = get_args()

     if args.load is not None:
         _ = load_checkpoint(model, None, None)

+    assert len(model) == 1, "Above condition should have caught this"
+    model = model[0]
+
     # Generate samples.
     if args.num_samples == 0:
         args.micro_batch_size = 1
...