Commit 7c3d8b7a authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

Merge branch 'main' into main_dedup

parents ee7b19e7 f32a638d
[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel (tensor and pipeline), and multi-node pre-training of [GPT](https://arxiv.org/abs/2005.14165) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision. Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf) and [2](https://arxiv.org/pdf/2104.04473.pdf)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel (tensor and pipeline), and multi-node pre-training of [GPT](https://arxiv.org/abs/2005.14165) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.
Below are some of the projects where we have directly used Megatron: Below are some of the projects where we have directly used Megatron:
* [BERT and GPT Studies Using Megatron](https://arxiv.org/pdf/1909.08053.pdf) * [BERT and GPT Studies Using Megatron](https://arxiv.org/pdf/1909.08053.pdf)
...@@ -8,15 +8,15 @@ Below are some of the projects where we have directly used Megatron: ...@@ -8,15 +8,15 @@ Below are some of the projects where we have directly used Megatron:
* [Local Knowledge Powered Conversational Agents](https://arxiv.org/abs/2010.10150) * [Local Knowledge Powered Conversational Agents](https://arxiv.org/abs/2010.10150)
* [MEGATRON-CNTRL: Controllable Story Generation with External Knowledge Using Large-Scale Language Models](https://www.aclweb.org/anthology/2020.emnlp-main.226.pdf) * [MEGATRON-CNTRL: Controllable Story Generation with External Knowledge Using Large-Scale Language Models](https://www.aclweb.org/anthology/2020.emnlp-main.226.pdf)
* [RACE Reading Comprehension Dataset Leaderboard](http://www.qizhexie.com/data/RACE_leaderboard.html) * [RACE Reading Comprehension Dataset Leaderboard](http://www.qizhexie.com/data/RACE_leaderboard.html)
* [Scaling Language Model Training to a Trillion Parameters Using Megatron](https://arxiv.org/pdf/2104.04473.pdf)
* [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf) * [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf)
Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specifc model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. The table below shows the model configurations along with the achieved FLOPs per second (both per GPU and aggregate over all GPUs). Note that the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging. Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specifc model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. The table below shows the model configurations along with the achieved FLOPs (both per GPU and aggregate over all GPUs). Note that the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging.
![Cases](images/cases_jan2021.png) ![Cases](images/cases_april2021.png)
The following figures show achieved percentage of theoretical peak FLOPs and achieved aggregate petaFLOPs per second as a function of number of GPUs. All the cases from 1 billion to 1 trillion achieve more than 41% half precision utilization, which is high for an end-to-end application. We observe that initially as the model parallel size increases, utilization slightly decreases; as hidden size increases for larger models, utilization starts increasing and reaches 49% for the largest model. We also note that achieved aggregate petaFLOPs per second across all GPUs increases almost linearly with number of GPUs, demonstrating good weak scaling. All the cases from 1 billion to 1 trillion parameters achieve more than 43% half precision utilization, which is high for an end-to-end application. We observe that initially the utilization remains constant but as hidden size increases for larger models, utilization starts increasing and reaches 52% for the largest model. We also note that achieved aggregate petaFLOPs across all GPUs increases almost linearly with number of GPUs, demonstrating good weak scaling.
![Model Parallel Scaling](images/scaling.png)
# Contents # Contents
* [Contents](#contents) * [Contents](#contents)
......
default: cases.png scaling-mp.png scaling-dp.png
# for some reason the size option to convert in scaling.tex doesn't work, manually do it after
cases.png scaling-mp.png scaling-dp.png: tables.tex
latex --shell-escape $<
convert tables-1.png -resize 650 cases.png
convert tables-2.png -resize 600 scaling-mp.png
convert tables-3.png -resize 350 scaling-dp.png
clean:
rm -rf *.aux *.log *.dvi *.ps
rm -rf tables-*.png
\documentclass[multi,convert]{standalone}
\usepackage{multirow}
\standaloneenv{tabular}
\begin{document}
\begin{tabular}{cccccc}
Case & Hidden Size & Attention Heads & Layers & Parameters (billions) & Model Parallel Partitions \\
\hline
1B & 1920 & 15 & 24 & 1.16 & 1 \\
2B & 2304 & 18 & 30 & 2.03 & 2 \\
4B & 3072 & 24 & 36 & 4.24 & 4 \\
8B & 4096 & 32 & 42 & 8.67 & 8 \\
\end{tabular}
\begin{tabular}{cc|ccc|ccc}
& & \multicolumn{3}{c|}{\textbf{DGX-2 (V100) batch size 8}} & \multicolumn{3}{c}{\textbf{DGX-A100 batch size 16}} \\
\hline
\multirow{2}{*}{Case} & Number of & Iteration & \multirow{2}{*}{Scaling} & TeraFLOPs & Iteration & \multirow{2}{*}{Scaling} & TeraFLOPs \\
& GPUs & Time (ms) & & per GPU & Time (ms) & & per GPU \\
\hline
1B & 1 & 1121 & 100.0\% & 71.9 & 1076 & 100\% & 149.8 \\
2B & 2 & 1093 & 89.6\% & 64.2 & 1026 & 91.7\% & 136.8 \\
4B & 4 & 1238 & 82.5\% & 58.5 & 1162 & 84.5\% & 124.7 \\
8B & 8 & 1407 & 74.3\% & 52.2 & 1343 & 74.7\% & 109.3 \\
\end{tabular}
\begin{tabular}{cc|ccc}
& & \multicolumn{3}{c}{\textbf{DGX-A100 batch size 2048}} \\
\hline
\multirow{2}{*}{Case} & Number of & Iteration & \multirow{2}{*}{Scaling} & TeraFLOPs \\
& GPUs & Time (ms) & & per GPU \\
\hline
1B & 128 & 1153 & 93.3\% & 139.8 \\
2B & 256 & 1101 & 85.5\% & 127.5 \\
4B & 512 & 1242 & 79.0\% & 116.7 \\
8B & 1024 & 1380 & 72.7\% & 106.5 \\
\end{tabular}
\end{document}
...@@ -136,6 +136,13 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -136,6 +136,13 @@ def parse_args(extra_args_provider=None, defaults={},
if args.bf16: if args.bf16:
assert not args.fp16 assert not args.fp16
args.params_dtype = torch.bfloat16 args.params_dtype = torch.bfloat16
# bfloat16 requires gradient accumulation and all-reduce to
# be done in fp32.
if not args.accumulate_allreduce_grads_in_fp32:
args.accumulate_allreduce_grads_in_fp32 = True
if args.rank == 0:
print('accumulate and all-reduce gradients in fp32 for '
'bfloat16 data type.', flush=True)
if args.rank == 0: if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype), print('using {} for parameters ...'.format(args.params_dtype),
......
...@@ -15,16 +15,8 @@ ...@@ -15,16 +15,8 @@
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from .distributed import * from .distributed import DistributedDataParallel
from .bert_model import (BertModel, from .bert_model import BertModel
BertModelFirstStage, from .gpt_model import GPTModel
BertModelIntermediateStage,
BertModelLastStage)
from .gpt_model import (GPTModel,
GPTModelFirstStage,
GPTModelIntermediateStage,
GPTModelLastStage)
from .language_model import get_language_model from .language_model import get_language_model
from .module import Float16Module from .module import Float16Module
...@@ -121,17 +121,23 @@ def post_language_model_processing(lm_output, pooled_output, ...@@ -121,17 +121,23 @@ def post_language_model_processing(lm_output, pooled_output,
return lm_loss, binary_logits return lm_loss, binary_logits
class BertModelBase(MegatronModule): class BertModel(MegatronModule):
"""Bert Language model.""" """Bert Language model."""
def __init__(self, num_tokentypes=2, add_binary_head=True, def __init__(self,
parallel_output=True): num_tokentypes=2,
super(BertModelBase, self).__init__() add_binary_head=True,
parallel_output=True,
pre_process=True,
post_process=True):
super(BertModel, self).__init__()
args = get_args() args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.add_binary_head = add_binary_head self.add_binary_head = add_binary_head
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std, scaled_init_method = scaled_init_method_normal(args.init_method_std,
...@@ -142,10 +148,12 @@ class BertModelBase(MegatronModule): ...@@ -142,10 +148,12 @@ class BertModelBase(MegatronModule):
add_pooler=self.add_binary_head, add_pooler=self.add_binary_head,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method) scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
self.initialize_word_embeddings(init_method_normal) self.initialize_word_embeddings(init_method_normal)
if mpu.is_pipeline_last_stage(): if self.post_process:
self.lm_head = BertLMHead( self.lm_head = BertLMHead(
self.word_embeddings_weight().size(0), self.word_embeddings_weight().size(0),
args.hidden_size, init_method, args.layernorm_epsilon, parallel_output) args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
...@@ -156,26 +164,30 @@ class BertModelBase(MegatronModule): ...@@ -156,26 +164,30 @@ class BertModelBase(MegatronModule):
init_method) init_method)
self._binary_head_key = 'binary_head' self._binary_head_key = 'binary_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, bert_model_input, attention_mask, def forward(self, bert_model_input, attention_mask,
tokentype_ids=None, lm_labels=None): tokentype_ids=None, lm_labels=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask) extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
kwargs = {} lm_output = self.language_model(
if mpu.is_pipeline_first_stage(): input_ids,
input_ids = bert_model_input position_ids,
position_ids = bert_position_ids(input_ids) extended_attention_mask,
args = [input_ids, position_ids, extended_attention_mask] tokentype_ids=tokentype_ids
kwargs['tokentype_ids'] = tokentype_ids )
else:
args = [bert_model_input, extended_attention_mask] if self.post_process and self.add_binary_head:
lm_output = self.language_model(*args, **kwargs)
if mpu.is_pipeline_last_stage() and self.add_binary_head:
lm_output, pooled_output = lm_output lm_output, pooled_output = lm_output
else: else:
pooled_output = None pooled_output = None
if mpu.is_pipeline_last_stage(): if self.post_process:
return post_language_model_processing(lm_output, pooled_output, return post_language_model_processing(lm_output, pooled_output,
self.lm_head, self.binary_head, self.lm_head, self.binary_head,
lm_labels, lm_labels,
...@@ -194,15 +206,15 @@ class BertModelBase(MegatronModule): ...@@ -194,15 +206,15 @@ class BertModelBase(MegatronModule):
state_dict_[self._language_model_key] \ state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage(): if self.post_process:
state_dict_[self._lm_head_key] \ state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint( = self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage() and self.add_binary_head: if self.post_process and self.add_binary_head:
state_dict_[self._binary_head_key] \ state_dict_[self._binary_head_key] \
= self.binary_head.state_dict(destination, prefix, keep_vars) = self.binary_head.state_dict(destination, prefix, keep_vars)
# Save word_embeddings. # Save word_embeddings.
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage(): if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] \ state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars) = self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_ return state_dict_
...@@ -212,74 +224,13 @@ class BertModelBase(MegatronModule): ...@@ -212,74 +224,13 @@ class BertModelBase(MegatronModule):
self.language_model.load_state_dict( self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict) state_dict[self._language_model_key], strict=strict)
if mpu.is_pipeline_last_stage(): if self.post_process:
self.lm_head.load_state_dict( self.lm_head.load_state_dict(
state_dict[self._lm_head_key], strict=strict) state_dict[self._lm_head_key], strict=strict)
if mpu.is_pipeline_last_stage() and self.add_binary_head: if self.post_process and self.add_binary_head:
self.binary_head.load_state_dict( self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict) state_dict[self._binary_head_key], strict=strict)
# Load word_embeddings. # Load word_embeddings.
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage(): if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict( self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict) state_dict[self._word_embeddings_for_head_key], strict=strict)
class BertModel(BertModelBase):
def __init__(self, num_tokentypes=2, add_binary_head=True,
parallel_output=True):
super(BertModel, self).__init__(
num_tokentypes=num_tokentypes,
add_binary_head=add_binary_head,
parallel_output=parallel_output)
def forward(self, input_ids, attention_mask,
tokentype_ids=None, lm_labels=None):
return super(BertModel, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids,
lm_labels=lm_labels)
class BertModelFirstStage(BertModelBase):
def __init__(self, num_tokentypes=2):
super(BertModelFirstStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(BertModelFirstStage, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class BertModelIntermediateStage(BertModelBase):
def __init__(self, num_tokentypes=2):
super(BertModelIntermediateStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(BertModelIntermediateStage, self).forward(
hidden_state,
attention_mask)
class BertModelLastStage(BertModelBase):
def __init__(self, num_tokentypes=2, add_binary_head=True,
parallel_output=True):
super(BertModelLastStage, self).__init__(
num_tokentypes=num_tokentypes,
add_binary_head=add_binary_head,
parallel_output=parallel_output)
def forward(self, hidden_state, attention_mask,
lm_labels=None):
return super(BertModelLastStage, self).forward(
hidden_state,
attention_mask,
lm_labels=lm_labels)
...@@ -28,13 +28,19 @@ from megatron.model.utils import scaled_init_method_normal ...@@ -28,13 +28,19 @@ from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule from .module import MegatronModule
class ClassificationBase(MegatronModule): class Classification(MegatronModule):
def __init__(self, num_classes, num_tokentypes=2): def __init__(self,
super(ClassificationBase, self).__init__(share_word_embeddings=False) num_classes,
num_tokentypes=2,
pre_process=True,
post_process=True):
super(Classification, self).__init__(share_word_embeddings=False)
args = get_args() args = get_args()
self.num_classes = num_classes self.num_classes = num_classes
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
...@@ -43,31 +49,36 @@ class ClassificationBase(MegatronModule): ...@@ -43,31 +49,36 @@ class ClassificationBase(MegatronModule):
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)
# Multi-choice head. # Multi-choice head.
if mpu.is_pipeline_last_stage(): if self.post_process:
self.classification_dropout = torch.nn.Dropout(args.hidden_dropout) self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
self.classification_head = get_linear_layer(args.hidden_size, self.classification_head = get_linear_layer(args.hidden_size,
self.num_classes, self.num_classes,
init_method) init_method)
self._classification_head_key = 'classification_head' self._classification_head_key = 'classification_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, model_input, attention_mask, tokentype_ids=None): def forward(self, model_input, attention_mask, tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask) extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = model_input
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids
)
kwargs = {} if self.post_process:
if mpu.is_pipeline_first_stage():
input_ids = model_input
position_ids = bert_position_ids(input_ids)
args = [input_ids, position_ids, extended_attention_mask]
kwargs['tokentype_ids'] = tokentype_ids
else:
args = [model_input, extended_attention_mask]
lm_output = self.language_model(*args, **kwargs)
if mpu.is_pipeline_last_stage():
_, pooled_output = lm_output _, pooled_output = lm_output
classification_output = self.classification_dropout(pooled_output) classification_output = self.classification_dropout(pooled_output)
classification_logits = self.classification_head(classification_output) classification_logits = self.classification_head(classification_output)
...@@ -87,7 +98,7 @@ class ClassificationBase(MegatronModule): ...@@ -87,7 +98,7 @@ class ClassificationBase(MegatronModule):
state_dict_[self._language_model_key] \ state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage(): if self.post_process:
state_dict_[self._classification_head_key] \ state_dict_[self._classification_head_key] \
= self.classification_head.state_dict( = self.classification_head.state_dict(
destination, prefix, keep_vars) destination, prefix, keep_vars)
...@@ -98,7 +109,7 @@ class ClassificationBase(MegatronModule): ...@@ -98,7 +109,7 @@ class ClassificationBase(MegatronModule):
self.language_model.load_state_dict( self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict) state_dict[self._language_model_key], strict=strict)
if mpu.is_pipeline_last_stage(): if self.post_process:
if self._classification_head_key in state_dict: if self._classification_head_key in state_dict:
self.classification_head.load_state_dict( self.classification_head.load_state_dict(
state_dict[self._classification_head_key], strict=strict) state_dict[self._classification_head_key], strict=strict)
...@@ -106,55 +117,3 @@ class ClassificationBase(MegatronModule): ...@@ -106,55 +117,3 @@ class ClassificationBase(MegatronModule):
print_rank_last('***WARNING*** could not find {} in the checkpoint, ' print_rank_last('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format( 'initializing to random'.format(
self._classification_head_key)) self._classification_head_key))
class Classification(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(Classification, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(Classification, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class ClassificationFirstStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationFirstStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(ClassificationFirstStage, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class ClassificationIntermediateStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationIntermediateStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(ClassificationIntermediateStage, self).forward(
hidden_state,
attention_mask)
class ClassificationLastStage(ClassificationBase):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationLastStage, self).__init__(
num_classes, num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(ClassificationLastStage, self).forward(
hidden_state,
attention_mask)
...@@ -57,14 +57,20 @@ def post_language_model_processing(lm_output, labels, logit_weights, ...@@ -57,14 +57,20 @@ def post_language_model_processing(lm_output, labels, logit_weights,
return loss return loss
class GPTModelBase(MegatronModule): class GPTModel(MegatronModule):
"""GPT-2 Language model.""" """GPT-2 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True): def __init__(self,
super(GPTModelBase, self).__init__() num_tokentypes=0,
parallel_output=True,
pre_process=True,
post_process=True):
super(GPTModel, self).__init__()
args = get_args() args = get_args()
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
...@@ -73,24 +79,28 @@ class GPTModelBase(MegatronModule): ...@@ -73,24 +79,28 @@ class GPTModelBase(MegatronModule):
encoder_attn_mask_type=AttnMaskType.causal, encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std), init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)
self.initialize_word_embeddings(init_method_normal) self.initialize_word_embeddings(init_method_normal)
def forward(self, gpt_model_input, attention_mask, labels=None, def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False, tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None): forward_method_parallel_output=None):
kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value} lm_output = self.language_model(
if mpu.is_pipeline_first_stage(): input_ids,
(input_ids, position_ids) = gpt_model_input position_ids,
args = [input_ids, position_ids, attention_mask] attention_mask,
kwargs['tokentype_ids'] = tokentype_ids layer_past=layer_past,
else: get_key_value=get_key_value)
args = [gpt_model_input, attention_mask]
lm_output = self.language_model(*args, **kwargs)
if mpu.is_pipeline_last_stage(): if self.post_process:
return post_language_model_processing( return post_language_model_processing(
lm_output, labels, lm_output, labels,
self.word_embeddings_weight(), self.word_embeddings_weight(),
...@@ -109,7 +119,7 @@ class GPTModelBase(MegatronModule): ...@@ -109,7 +119,7 @@ class GPTModelBase(MegatronModule):
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
# Save word_embeddings. # Save word_embeddings.
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage(): if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] \ state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars) = self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_ return state_dict_
...@@ -118,79 +128,9 @@ class GPTModelBase(MegatronModule): ...@@ -118,79 +128,9 @@ class GPTModelBase(MegatronModule):
"""Customized load.""" """Customized load."""
# Load word_embeddings. # Load word_embeddings.
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage(): if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict( self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict) state_dict[self._word_embeddings_for_head_key], strict=strict)
if self._language_model_key in state_dict: if self._language_model_key in state_dict:
state_dict = state_dict[self._language_model_key] state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict) self.language_model.load_state_dict(state_dict, strict=strict)
class GPTModel(GPTModelBase):
def __init__(self, num_tokentypes=0, parallel_output=True):
super(GPTModel, self).__init__(
num_tokentypes=num_tokentypes,
parallel_output=parallel_output)
def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
return super(GPTModel, self).forward(
(input_ids, position_ids),
attention_mask,
labels=labels,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
class GPTModelFirstStage(GPTModelBase):
def __init__(self, num_tokentypes=0):
super(GPTModelFirstStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
return super(GPTModelFirstStage, self).forward(
(input_ids, position_ids),
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value)
class GPTModelIntermediateStage(GPTModelBase):
def __init__(self, num_tokentypes=0):
super(GPTModelIntermediateStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask,
layer_past=None, get_key_value=False):
return super(GPTModelIntermediateStage, self).forward(
hidden_state,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
class GPTModelLastStage(GPTModelBase):
def __init__(self, num_tokentypes=0, parallel_output=True):
super(GPTModelLastStage, self).__init__(
num_tokentypes=num_tokentypes,
parallel_output=parallel_output)
def forward(self, hidden_state, attention_mask, labels=None,
layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
return super(GPTModelLastStage, self).forward(
hidden_state,
attention_mask,
labels=labels,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
...@@ -46,7 +46,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, ...@@ -46,7 +46,8 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(num_tokentypes, add_pooler, def get_language_model(num_tokentypes, add_pooler,
encoder_attn_mask_type, init_method=None, encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_decoder=False, scaled_init_method=None, add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal): decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True, post_process=True):
"""Build language model and return along with the key to save.""" """Build language model and return along with the key to save."""
args = get_args() args = get_args()
...@@ -58,26 +59,17 @@ def get_language_model(num_tokentypes, add_pooler, ...@@ -58,26 +59,17 @@ def get_language_model(num_tokentypes, add_pooler,
args.num_layers) args.num_layers)
# Language model. # Language model.
args = [init_method, scaled_init_method, encoder_attn_mask_type] language_model = TransformerLanguageModel(
kwargs = {} init_method,
cls = None scaled_init_method,
if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage(): encoder_attn_mask_type,
cls = TransformerLanguageModel num_tokentypes=num_tokentypes,
kwargs['num_tokentypes'] = num_tokentypes add_decoder=add_decoder,
kwargs['add_decoder'] = add_decoder decoder_attn_mask_type=decoder_attn_mask_type,
kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type add_pooler=add_pooler,
kwargs['add_pooler'] = add_pooler pre_process=pre_process,
elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage(): post_process=post_process
cls = TransformerLanguageModelFirstStage )
kwargs['num_tokentypes'] = num_tokentypes
elif not mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
cls = TransformerLanguageModelLastStage
kwargs['add_pooler'] = add_pooler
else:
cls = TransformerLanguageModelIntermediateStage
# Language model.
language_model = cls(*args, **kwargs)
# key used for checkpoints. # key used for checkpoints.
language_model_key = 'language_model' language_model_key = 'language_model'
...@@ -263,7 +255,7 @@ class Embedding(MegatronModule): ...@@ -263,7 +255,7 @@ class Embedding(MegatronModule):
'checkpoint but could not find it', flush=True) 'checkpoint but could not find it', flush=True)
class TransformerLanguageModelBase(MegatronModule): class TransformerLanguageModel(MegatronModule):
"""Transformer language model. """Transformer language model.
Arguments: Arguments:
...@@ -283,10 +275,14 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -283,10 +275,14 @@ class TransformerLanguageModelBase(MegatronModule):
num_tokentypes=0, num_tokentypes=0,
add_decoder=False, add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal, decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False): add_pooler=False,
super(TransformerLanguageModelBase, self).__init__() pre_process=True,
post_process=True):
super(TransformerLanguageModel, self).__init__()
args = get_args() args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes self.num_tokentypes = num_tokentypes
self.init_method = init_method self.init_method = init_method
...@@ -296,7 +292,7 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -296,7 +292,7 @@ class TransformerLanguageModelBase(MegatronModule):
self.add_pooler = add_pooler self.add_pooler = add_pooler
# Embeddings. # Embeddings.
if mpu.is_pipeline_first_stage(): if self.pre_process:
self.embedding = Embedding(self.hidden_size, self.embedding = Embedding(self.hidden_size,
args.padded_vocab_size, args.padded_vocab_size,
args.max_position_embeddings, args.max_position_embeddings,
...@@ -309,7 +305,10 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -309,7 +305,10 @@ class TransformerLanguageModelBase(MegatronModule):
self.encoder = ParallelTransformer( self.encoder = ParallelTransformer(
self.init_method, self.init_method,
output_layer_init_method, output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type) self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process
)
self._encoder_key = 'encoder' self._encoder_key = 'encoder'
# Decoder # Decoder
...@@ -323,26 +322,29 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -323,26 +322,29 @@ class TransformerLanguageModelBase(MegatronModule):
self_attn_mask_type=self.decoder_attn_mask_type) self_attn_mask_type=self.decoder_attn_mask_type)
self._decoder_key = 'decoder' self._decoder_key = 'decoder'
if mpu.is_pipeline_last_stage(): if self.post_process:
# Pooler. # Pooler.
if self.add_pooler: if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method) self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler' self._pooler_key = 'pooler'
def forward(self, enc_language_model_input, enc_attn_mask, def set_input_tensor(self, input_tensor):
dec_language_model_input=None, dec_attn_mask=None, """ See megatron.model.transformer.set_input_tensor()"""
self.encoder.set_input_tensor(input_tensor)
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None, enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0, get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False): enc_hidden_states=None, output_enc_hidden=False):
# Embeddings. # Embeddings.
if mpu.is_pipeline_first_stage(): if self.pre_process:
(input_ids, position_ids) = enc_language_model_input embedding_output = self.embedding(enc_input_ids, enc_position_ids,
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids) tokentype_ids=tokentype_ids)
encoder_input = embedding_output encoder_input = embedding_output
else: else:
encoder_input = enc_language_model_input encoder_input = None
# encoder. # encoder.
if enc_hidden_states is None: if enc_hidden_states is None:
...@@ -353,7 +355,7 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -353,7 +355,7 @@ class TransformerLanguageModelBase(MegatronModule):
else: else:
encoder_output = enc_hidden_states.to(encoder_input.dtype) encoder_output = enc_hidden_states.to(encoder_input.dtype)
if mpu.is_pipeline_last_stage(): if self.post_process:
if self.add_pooler: if self.add_pooler:
pooled_output = self.pooler(encoder_output, pooled_output = self.pooler(encoder_output,
pooling_sequence_index) pooling_sequence_index)
...@@ -362,13 +364,12 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -362,13 +364,12 @@ class TransformerLanguageModelBase(MegatronModule):
# output. For example, it is helpful to compute # output. For example, it is helpful to compute
# similarity between two sequences by average pooling # similarity between two sequences by average pooling
if not self.add_decoder or output_enc_hidden: if not self.add_decoder or output_enc_hidden:
if self.add_pooler and mpu.is_pipeline_last_stage(): if self.add_pooler and self.post_process:
return encoder_output, pooled_output return encoder_output, pooled_output
else: else:
return encoder_output return encoder_output
# Decoder Embedding # Decoder Embedding
(dec_input_ids, dec_position_ids) = dec_language_model_input
dec_embedding_output = self.embedding(dec_input_ids, dec_embedding_output = self.embedding(dec_input_ids,
dec_position_ids) dec_position_ids)
# decoder # decoder
...@@ -379,7 +380,7 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -379,7 +380,7 @@ class TransformerLanguageModelBase(MegatronModule):
encoder_output=encoder_output, encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask) enc_dec_attn_mask=enc_dec_attn_mask)
if self.add_pooler and mpu.is_pipeline_last_stage(): if self.add_pooler and self.post_process:
return decoder_output, encoder_output, pooled_output return decoder_output, encoder_output, pooled_output
else: else:
return decoder_output, encoder_output return decoder_output, encoder_output
...@@ -389,14 +390,14 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -389,14 +390,14 @@ class TransformerLanguageModelBase(MegatronModule):
"""For easy load.""" """For easy load."""
state_dict_ = {} state_dict_ = {}
if mpu.is_pipeline_first_stage(): if self.pre_process:
state_dict_[self._embedding_key] \ state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint( = self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
state_dict_[self._encoder_key] \ state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint( = self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage(): if self.post_process:
if self.add_pooler: if self.add_pooler:
state_dict_[self._pooler_key] \ state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint( = self.pooler.state_dict_for_save_checkpoint(
...@@ -412,7 +413,7 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -412,7 +413,7 @@ class TransformerLanguageModelBase(MegatronModule):
"""Customized load.""" """Customized load."""
# Embedding. # Embedding.
if mpu.is_pipeline_first_stage(): if self.pre_process:
if self._embedding_key in state_dict: if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key] state_dict_ = state_dict[self._embedding_key]
else: else:
...@@ -448,7 +449,7 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -448,7 +449,7 @@ class TransformerLanguageModelBase(MegatronModule):
self.encoder.load_state_dict(state_dict_, strict=strict) self.encoder.load_state_dict(state_dict_, strict=strict)
if mpu.is_pipeline_last_stage(): if self.post_process:
# pooler # pooler
if self.add_pooler: if self.add_pooler:
assert 'pooler' in state_dict, \ assert 'pooler' in state_dict, \
...@@ -461,124 +462,3 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -461,124 +462,3 @@ class TransformerLanguageModelBase(MegatronModule):
'could not find data for pooler in the checkpoint' 'could not find data for pooler in the checkpoint'
self.decoder.load_state_dict(state_dict[self._decoder_key], self.decoder.load_state_dict(state_dict[self._decoder_key],
strict=strict) strict=strict)
class TransformerLanguageModel(TransformerLanguageModelBase):
"""Transformer language model (see TransformerLanguageModelBase
for description of arguments).
"""
def __init__(self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
decoder_attn_mask_type=AttnMaskType.causal,
add_decoder=False,
add_pooler=False):
super(TransformerLanguageModel, self).__init__(
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler)
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
return super(TransformerLanguageModel, self).forward(
(enc_input_ids, enc_position_ids),
enc_attn_mask,
dec_language_model_input=(dec_input_ids, dec_position_ids),
dec_attn_mask=dec_attn_mask,
enc_dec_attn_mask=enc_dec_attn_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index,
enc_hidden_states=enc_hidden_states,
output_enc_hidden=output_enc_hidden
)
class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
"""Transformer language model, first stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0):
super(TransformerLanguageModelFirstStage, self).__init__(
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
return super(TransformerLanguageModelFirstStage, self).forward(
(input_ids, position_ids),
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value
)
class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
"""Transformer language model, intermediate stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
init_method,
output_layer_init_method,
encoder_attn_mask_type):
super(TransformerLanguageModelIntermediateStage, self).__init__(
init_method,
output_layer_init_method,
encoder_attn_mask_type)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False):
return super(TransformerLanguageModelIntermediateStage, self).forward(
hidden_states,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value
)
class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
"""Transformer language model, final stage (see
TransformerLanguageModelBase for description of arguments).
"""
def __init__(self,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__(
init_method,
output_layer_init_method,
encoder_attn_mask_type,
add_pooler=add_pooler)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False,
pooling_sequence_index=0):
return super(TransformerLanguageModelLastStage, self).forward(
hidden_states,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index,
)
...@@ -28,13 +28,18 @@ from megatron.model.utils import scaled_init_method_normal ...@@ -28,13 +28,18 @@ from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule from .module import MegatronModule
class MultipleChoiceBase(MegatronModule): class MultipleChoice(MegatronModule):
def __init__(self, num_tokentypes=2): def __init__(self,
super(MultipleChoiceBase, self).__init__(share_word_embeddings=False) num_tokentypes=2,
pre_process=True,
post_process=True):
super(MultipleChoice, self).__init__(share_word_embeddings=False)
args = get_args() args = get_args()
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
self.pre_process = pre_process
self.post_process = post_process
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
...@@ -42,15 +47,21 @@ class MultipleChoiceBase(MegatronModule): ...@@ -42,15 +47,21 @@ class MultipleChoiceBase(MegatronModule):
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)
# Multi-choice head. # Multi-choice head.
if mpu.is_pipeline_last_stage(): if self.post_process:
self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout) self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
self.multichoice_head = get_linear_layer(args.hidden_size, 1, self.multichoice_head = get_linear_layer(args.hidden_size, 1,
init_method) init_method)
self._multichoice_head_key = 'multichoice_head' self._multichoice_head_key = 'multichoice_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, model_input, attention_mask, tokentype_ids=None): def forward(self, model_input, attention_mask, tokentype_ids=None):
# [batch, choices, sequence] --> [batch * choices, sequence] --> # [batch, choices, sequence] --> [batch * choices, sequence] -->
...@@ -64,22 +75,21 @@ class MultipleChoiceBase(MegatronModule): ...@@ -64,22 +75,21 @@ class MultipleChoiceBase(MegatronModule):
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) attention_mask = attention_mask.view(-1, attention_mask.size(-1))
extended_attention_mask = bert_extended_attention_mask(attention_mask) extended_attention_mask = bert_extended_attention_mask(attention_mask)
kwargs = {} input_ids = model_input
if mpu.is_pipeline_first_stage(): # Do the same as attention_mask for input_ids, tokentype_ids
input_ids = model_input assert len(input_ids.shape) == 3
# Do the same as attention_mask for input_ids, tokentype_ids assert len(tokentype_ids.shape) == 3
assert len(input_ids.shape) == 3 input_ids = input_ids.view(-1, input_ids.size(-1))
assert len(tokentype_ids.shape) == 3 tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
input_ids = input_ids.view(-1, input_ids.size(-1)) position_ids = bert_position_ids(input_ids)
tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))
lm_output = self.language_model(
position_ids = bert_position_ids(input_ids) input_ids,
args = [input_ids, position_ids, extended_attention_mask] position_ids,
kwargs['tokentype_ids'] = tokentype_ids extended_attention_mask,
else: tokentype_ids=tokentype_ids
args = [model_input, extended_attention_mask] )
lm_output = self.language_model(*args, **kwargs) if self.post_process:
if mpu.is_pipeline_last_stage():
_, pooled_output = lm_output _, pooled_output = lm_output
multichoice_output = self.multichoice_dropout(pooled_output) multichoice_output = self.multichoice_dropout(pooled_output)
multichoice_logits = self.multichoice_head(multichoice_output) multichoice_logits = self.multichoice_head(multichoice_output)
...@@ -99,7 +109,7 @@ class MultipleChoiceBase(MegatronModule): ...@@ -99,7 +109,7 @@ class MultipleChoiceBase(MegatronModule):
state_dict_[self._language_model_key] \ state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint( = self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage(): if self.post_process:
state_dict_[self._multichoice_head_key] \ state_dict_[self._multichoice_head_key] \
= self.multichoice_head.state_dict( = self.multichoice_head.state_dict(
destination, prefix, keep_vars) destination, prefix, keep_vars)
...@@ -110,7 +120,7 @@ class MultipleChoiceBase(MegatronModule): ...@@ -110,7 +120,7 @@ class MultipleChoiceBase(MegatronModule):
self.language_model.load_state_dict( self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict) state_dict[self._language_model_key], strict=strict)
if mpu.is_pipeline_last_stage(): if self.post_process:
if self._multichoice_head_key in state_dict: if self._multichoice_head_key in state_dict:
self.multichoice_head.load_state_dict( self.multichoice_head.load_state_dict(
state_dict[self._multichoice_head_key], strict=strict) state_dict[self._multichoice_head_key], strict=strict)
...@@ -118,54 +128,3 @@ class MultipleChoiceBase(MegatronModule): ...@@ -118,54 +128,3 @@ class MultipleChoiceBase(MegatronModule):
print_rank_last('***WARNING*** could not find {} in the checkpoint, ' print_rank_last('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format( 'initializing to random'.format(
self._multichoice_head_key)) self._multichoice_head_key))
class MultipleChoice(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoice, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(MultipleChoice, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class MultipleChoiceFirstStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceFirstStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(MultipleChoiceFirstStage, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class MultipleChoiceIntermediateStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceIntermediateStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(MultipleChoiceIntermediateStage, self).forward(
hidden_state,
attention_mask)
class MultipleChoiceLastStage(MultipleChoiceBase):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceLastStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(MultipleChoiceLastStage, self).forward(
hidden_state,
attention_mask)
...@@ -532,12 +532,16 @@ class ParallelTransformer(MegatronModule): ...@@ -532,12 +532,16 @@ class ParallelTransformer(MegatronModule):
def __init__(self, init_method, output_layer_init_method, def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding): self_attn_mask_type=AttnMaskType.padding,
pre_process=True, post_process=True):
super(ParallelTransformer, self).__init__() super(ParallelTransformer, self).__init__()
args = get_args() args = get_args()
self.bf16 = args.bf16 self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
# Store activation checkpoiting flag. # Store activation checkpoiting flag.
self.checkpoint_activations = args.checkpoint_activations self.checkpoint_activations = args.checkpoint_activations
...@@ -572,15 +576,16 @@ class ParallelTransformer(MegatronModule): ...@@ -572,15 +576,16 @@ class ParallelTransformer(MegatronModule):
# Stage 0: [0, 1] [4, 5] # Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7] # Stage 1: [2, 3] [6, 7]
offset = mpu.get_virtual_pipeline_model_parallel_rank() * ( offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \ args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(mpu.get_pipeline_model_parallel_rank() * self.num_layers) (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
else: else:
# Each stage gets a contiguous set of layers. # Each stage gets a contiguous set of layers.
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList( self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)]) [build_layer(i + 1 + offset) for i in range(self.num_layers)])
if mpu.is_pipeline_last_stage(): if self.post_process:
# Final layer norm before output. # Final layer norm before output.
self.final_layernorm = LayerNorm( self.final_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
...@@ -615,6 +620,16 @@ class ParallelTransformer(MegatronModule): ...@@ -615,6 +620,16 @@ class ParallelTransformer(MegatronModule):
return hidden_states return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask, layer_past=None, def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False, encoder_output=None, enc_dec_attn_mask=None): get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
...@@ -628,7 +643,7 @@ class ParallelTransformer(MegatronModule): ...@@ -628,7 +643,7 @@ class ParallelTransformer(MegatronModule):
'get_key_value does not work with ' \ 'get_key_value does not work with ' \
'activation checkpointing' 'activation checkpointing'
if mpu.is_pipeline_first_stage(): if self.pre_process:
# Data format change to avoid explicit tranposes : [b s h] --> [s b h]. # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert for float. # If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection: if self.fp32_residual_connection:
...@@ -636,10 +651,13 @@ class ParallelTransformer(MegatronModule): ...@@ -636,10 +651,13 @@ class ParallelTransformer(MegatronModule):
# Otherwise, leave it as is. # Otherwise, leave it as is.
else: else:
hidden_states = hidden_states.transpose(0, 1).contiguous() hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
# See set_input_tensor()
hidden_states = self.input_tensor
if encoder_output is not None: if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous() encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.checkpoint_activations: if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states, hidden_states = self._checkpointed_forward(hidden_states,
attention_mask, attention_mask,
...@@ -664,7 +682,7 @@ class ParallelTransformer(MegatronModule): ...@@ -664,7 +682,7 @@ class ParallelTransformer(MegatronModule):
presents.append(present) presents.append(present)
# Final layer norm. # Final layer norm.
if mpu.is_pipeline_last_stage(): if self.post_process:
# Reverting data format change [s b h] --> [b s h]. # Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous() hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states) output = self.final_layernorm(hidden_states)
......
...@@ -22,6 +22,20 @@ from megatron import get_num_microbatches ...@@ -22,6 +22,20 @@ from megatron import get_num_microbatches
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron import p2p_communication from megatron import p2p_communication
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def get_forward_backward_func():
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced): def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
...@@ -34,8 +48,12 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r ...@@ -34,8 +48,12 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
timers = get_timers() timers = get_timers()
timers('forward-compute').start() timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor) unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches() output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced) losses_reduced.append(loss_reduced)
......
...@@ -26,9 +26,13 @@ import torch.nn.functional as F ...@@ -26,9 +26,13 @@ import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
from megatron.training import communicate from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.utils import get_ltor_masks_and_position_ids from megatron.p2p_communication import recv_forward, send_forward
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def get_batch(context_tokens): def get_batch(context_tokens):
"""Generate batch from context tokens.""" """Generate batch from context tokens."""
...@@ -395,55 +399,28 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids, ...@@ -395,55 +399,28 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
layer_past=None, get_key_value=None, layer_past=None, get_key_value=None,
forward_method_parallel_output=None): forward_method_parallel_output=None):
# Hidden size changes when not using recompute, need to tell communicate() # Hidden size changes when not using recompute, need to tell p2p_communicate
# the correct size # functions the correct size
args = get_args() args = get_args()
orig_seq_length = args.seq_length orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1] args.seq_length = tokens.shape[1]
if not mpu.is_pipeline_first_stage(): input_tensor = recv_forward()
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
# Forward pass through the model. # Forward pass through the model.
if mpu.is_pipeline_first_stage(): unwrapped_model = unwrap_model(
assert input_tensor is None model, (torchDDP, LocalDDP, Float16Module))
if mpu.is_pipeline_last_stage(): unwrapped_model.set_input_tensor(input_tensor)
output_tensor = model(tokens, position_ids, attention_mask, output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids, tokentype_ids=tokentype_ids,
layer_past=layer_past, layer_past=layer_past,
get_key_value=get_key_value, get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output) forward_method_parallel_output=forward_method_parallel_output)
else:
output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if get_key_value: if get_key_value:
output_tensor, layer_past = output_tensor output_tensor, layer_past = output_tensor
if not mpu.is_pipeline_last_stage(): send_forward(output_tensor)
communicate(tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
args.seq_length = orig_seq_length args.seq_length = orig_seq_length
if get_key_value: if get_key_value:
......
...@@ -61,10 +61,10 @@ def print_datetime(string): ...@@ -61,10 +61,10 @@ def print_datetime(string):
print_rank_0('[' + string + '] datetime: {} '.format(time_str)) print_rank_0('[' + string + '] datetime: {} '.format(time_str))
def pretrain(train_valid_test_dataset_provider, def pretrain(train_valid_test_dataset_provider,
model_provider, model_provider,
forward_step_func, forward_step_func,
extra_args_provider=None, extra_args_provider=None,
args_defaults={}): args_defaults={}):
"""Main training program. """Main training program.
...@@ -195,8 +195,28 @@ def get_model(model_provider_func): ...@@ -195,8 +195,28 @@ def get_model(model_provider_func):
"""Build the model.""" """Build the model."""
args = get_args() args = get_args()
# Build model on cpu. # Build model.
model = model_provider_func() if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
if not isinstance(model, list): if not isinstance(model, list):
model = [model] model = [model]
...@@ -231,7 +251,7 @@ def get_model(model_provider_func): ...@@ -231,7 +251,7 @@ def get_model(model_provider_func):
process_group=mpu.get_data_parallel_group()) process_group=mpu.get_data_parallel_group())
for model_module in model] for model_module in model]
return model return model
if args.DDP_impl == 'local': if args.DDP_impl == 'local':
model = [LocalDDP(model_module, model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32, args.accumulate_allreduce_grads_in_fp32,
...@@ -651,16 +671,16 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -651,16 +671,16 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if not saved_checkpoint: if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler) lr_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time)) print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit() sys.exit()
# Exiting based on iterations # Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0: if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint: if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler) lr_scheduler)
torch.distributed.barrier() torch.distributed.barrier()
print_datetime('exiting program at iteration {}'.format(iteration)) print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit() sys.exit()
......
...@@ -17,56 +17,30 @@ ...@@ -17,56 +17,30 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from functools import partial
from megatron import get_args from megatron import get_args
from megatron import print_rank_0 from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import (BertModel, from megatron.model import BertModel
BertModelFirstStage,
BertModelIntermediateStage,
BertModelLastStage)
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
def model_provider(): def model_provider(pre_process=True, post_process=True):
"""Build the model.""" """Build the model."""
print_rank_0('building BERT model ...') print_rank_0('building BERT model ...')
args = get_args() args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0 num_tokentypes = 2 if args.bert_binary_head else 0
def model_provider_pipelined(): model = BertModel(
# Determine model based on position of stage in pipeline. num_tokentypes=num_tokentypes,
if mpu.is_pipeline_first_stage(): add_binary_head=args.bert_binary_head,
model = BertModelFirstStage( parallel_output=True,
num_tokentypes=num_tokentypes) pre_process=pre_process,
elif mpu.is_pipeline_last_stage(): post_process=post_process)
model = BertModelLastStage(
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True)
else:
model = BertModelIntermediateStage(
num_tokentypes=num_tokentypes)
return model
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model.append(model_provider_pipelined())
else:
model = model_provider_pipelined()
else:
model = BertModel(
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True)
return model return model
...@@ -96,7 +70,33 @@ def get_batch(data_iterator): ...@@ -96,7 +70,33 @@ def get_batch(data_iterator):
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def forward_step(data_iterator, model, input_tensor): def loss_func(loss_mask, sentence_order, output_tensor):
lm_loss_, sop_logits = output_tensor
lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
if sop_logits is not None:
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss, sop_loss])
return loss, {'lm loss': averaged_losses[0],
'sop loss': averaged_losses[1]}
else:
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss])
return loss, {'lm loss': averaged_losses[0]}
def forward_step(data_iterator, model):
"""Forward step.""" """Forward step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
...@@ -111,46 +111,10 @@ def forward_step(data_iterator, model, input_tensor): ...@@ -111,46 +111,10 @@ def forward_step(data_iterator, model, input_tensor):
types = None types = None
# Forward pass through the model. # Forward pass through the model.
if mpu.is_pipeline_first_stage(): output_tensor = model(tokens, padding_mask, tokentype_ids=types,
assert input_tensor is None lm_labels=lm_labels)
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, padding_mask, tokentype_ids=types, return output_tensor, partial(loss_func, loss_mask, sentence_order)
lm_labels=lm_labels)
else:
output_tensor = model(tokens, padding_mask, tokentype_ids=types)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask)
if mpu.is_pipeline_last_stage():
lm_loss_, sop_logits = output_tensor
lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
if sop_logits is not None:
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss, sop_loss])
return loss, {'lm loss': averaged_losses[0],
'sop loss': averaged_losses[1]}
else:
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss])
return loss, {'lm loss': averaged_losses[0]}
return output_tensor
def train_valid_test_datasets_provider(train_val_test_num_samples): def train_valid_test_datasets_provider(train_val_test_num_samples):
......
...@@ -16,50 +16,28 @@ ...@@ -16,50 +16,28 @@
"""Pretrain GPT""" """Pretrain GPT"""
import torch import torch
from functools import partial
from megatron import get_args from megatron import get_args
from megatron import print_rank_0 from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
from megatron.data.gpt_dataset import build_train_valid_test_datasets from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import (GPTModel, from megatron.model import GPTModel
GPTModelFirstStage,
GPTModelIntermediateStage,
GPTModelLastStage)
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
def model_provider(): def model_provider(pre_process=True, post_process=True):
"""Build the model.""" """Build the model."""
print_rank_0('building GPT model ...') print_rank_0('building GPT model ...')
model = GPTModel(
def model_provider_pipelined(): num_tokentypes=0,
# Determine model based on position of stage in pipeline. parallel_output=True,
if mpu.is_pipeline_first_stage(): pre_process=pre_process,
model = GPTModelFirstStage(num_tokentypes=0) post_process=post_process
elif mpu.is_pipeline_last_stage(): )
model = GPTModelLastStage(
num_tokentypes=0, parallel_output=True)
else:
model = GPTModelIntermediateStage(
num_tokentypes=0)
return model
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model.append(model_provider_pipelined())
else:
model = model_provider_pipelined()
else:
model = GPTModel(num_tokentypes=0, parallel_output=True)
return model return model
...@@ -94,8 +72,18 @@ def get_batch(data_iterator): ...@@ -94,8 +72,18 @@ def get_batch(data_iterator):
return tokens, labels, loss_mask, attention_mask, position_ids return tokens, labels, loss_mask, attention_mask, position_ids
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
def forward_step(data_iterator, model, input_tensor): # Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model):
"""Forward step.""" """Forward step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
...@@ -106,31 +94,10 @@ def forward_step(data_iterator, model, input_tensor): ...@@ -106,31 +94,10 @@ def forward_step(data_iterator, model, input_tensor):
data_iterator) data_iterator)
timers('batch-generator').stop() timers('batch-generator').stop()
# Forward pass through the model. output_tensor = model(tokens, position_ids, attention_mask,
if mpu.is_pipeline_first_stage(): labels=labels)
assert input_tensor is None
if mpu.is_pipeline_last_stage():
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
else:
output_tensor = model(tokens, position_ids, attention_mask)
elif mpu.is_pipeline_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask, labels=labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]} return output_tensor, partial(loss_func, loss_mask)
return output_tensor
def train_valid_test_datasets_provider(train_val_test_num_samples): def train_valid_test_datasets_provider(train_val_test_num_samples):
......
...@@ -17,13 +17,14 @@ ...@@ -17,13 +17,14 @@
import os import os
import time import time
from functools import partial
import torch import torch
from megatron import get_args from megatron import get_args
from megatron import print_rank_last, is_last_rank from megatron import print_rank_last, is_last_rank
from megatron import mpu from megatron import mpu
from megatron.training import communicate from megatron.schedules import get_forward_backward_func
from tasks.finetune_utils import build_data_loader from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch from tasks.finetune_utils import process_batch
...@@ -38,7 +39,7 @@ def accuracy_func_provider(single_dataset_provider): ...@@ -38,7 +39,7 @@ def accuracy_func_provider(single_dataset_provider):
for datapath in datapaths: for datapath in datapaths:
dataset = single_dataset_provider(datapath) dataset = single_dataset_provider(datapath)
dataloader = build_data_loader( dataloader = build_data_loader(
dataset, args.micro_batch_size, num_workers=args.num_workers, dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1)) drop_last=(mpu.get_data_parallel_world_size() > 1))
dataloaders.append((dataset.dataset_name, dataloader)) dataloaders.append((dataset.dataset_name, dataloader))
...@@ -73,14 +74,66 @@ def accuracy_func_provider(single_dataset_provider): ...@@ -73,14 +74,66 @@ def accuracy_func_provider(single_dataset_provider):
return metrics_func return metrics_func
def calculate_correct_answers(name, model, dataloader, def calculate_correct_answers(name, model, dataloader,
epoch, output_predictions): epoch, output_predictions):
"""Calculate correct over total answers and return prediction if the """Calculate correct over total answers and return prediction if the
`output_predictions` is true.""" `output_predictions` is true."""
args = get_args() args = get_args()
forward_backward_func = get_forward_backward_func()
start_time = time.time() start_time = time.time()
model.eval() for m in model:
saved_batch_size = args.micro_batch_size m.eval()
saved_micro_batch_size = args.micro_batch_size
saved_global_batch_size = args.global_batch_size
ds = dataloader.dataset
if hasattr(ds, 'sample_multiplier'):
# If our dataset as a sample_multiplier attribute that means
# each "sample" from the dataset actually has multiple samples
# that will collapse into the batch dimension (for example in
# the RACE dataset that has several options), we need to
# account for that when setting the micro batch size.
sample_multiplier = ds.sample_multiplier
else:
sample_multiplier = 1
micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel
def loss_func(output_predictions, labels, output_tensor):
logits = output_tensor
loss_dict = {}
# Add output predictions.
if output_predictions:
assert False
loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
logits.float()).data.cpu().numpy().tolist()
loss_dict['labels'] = labels.data.cpu().numpy().tolist()
loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels)
# Add to the counters.
loss_dict['total'] = labels.size(0)
loss_dict['correct'] = corrects.sum().item()
return 0, loss_dict
# defined inside to capture output_predictions
def correct_answers_forward_step(batch, model):
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
tokens, types, labels, attention_mask = process_batch(batch_)
# Forward model.
args = get_args()
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
return output_tensor, partial(loss_func, output_predictions, labels)
with torch.no_grad(): with torch.no_grad():
# For all the batches in the dataset. # For all the batches in the dataset.
total = 0 total = 0
...@@ -92,60 +145,30 @@ def calculate_correct_answers(name, model, dataloader, ...@@ -92,60 +145,30 @@ def calculate_correct_answers(name, model, dataloader,
labels = [] labels = []
ids = [] ids = []
for _, batch in enumerate(dataloader): for _, batch in enumerate(dataloader):
# Run the model forward.
tokens, types, labels_, attention_mask = process_batch(batch)
# For evaluation only mode we use drop_last = False to get all the # For evaluation only mode we use drop_last = False to get all the
# samples, which means we might not have a full batch, so we # samples, which means we might not have a full batch, so we
# adjust batch_size here to actual batch size of data # adjust batch_size here to actual batch size of data
actual_batch_size = len(labels_) actual_batch_size = len(batch['label'])
# ... applying sample_multiplier if necessary # ... applying sample_multiplier if necessary
ds = dataloader.dataset args.micro_batch_size = actual_batch_size * sample_multiplier
if hasattr(ds, 'sample_multiplier'): args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches
actual_batch_size *= ds.sample_multiplier
args.micro_batch_size = actual_batch_size
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
# Forward model. loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
if mpu.is_pipeline_first_stage(): optimizer=None, timers=None, forward_only=True)
assert input_tensor is None
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_pipeline_last_stage():
logits = output_tensor
# Add output predictions. for loss_dict in loss_dicts:
if output_predictions: if output_predictions:
softmaxes.extend(torch.nn.Softmax(dim=-1)( softmaxes.extend(loss_dict['softmaxes'])
logits.float()).data.cpu().numpy().tolist()) labels.extend(loss_dict['labels'])
labels.extend(labels_.data.cpu().numpy().tolist()) ids.extend(loss_dict['ids'])
ids.extend(batch['uid'].cpu().numpy().tolist()) total += loss_dict['total']
# Compute the correct answers. correct += loss_dict['correct']
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels_)
# Add to the counters. for m in model:
total += labels_.size(0) m.train()
correct += corrects.sum().item() args.micro_batch_size = saved_micro_batch_size
else: args.global_batch_size = saved_global_batch_size
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
model.train()
args.micro_batch_size = saved_batch_size
# Reduce. # Reduce.
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment