Commit aed2f75e authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main' into github-main

parents 8aa4619f f32a638d
image: gitlab-master.nvidia.com/dl/dgx/pytorch:20.12-py3-devel
test:
script:
- pytest --junitxml=report.xml tests
artifacts:
when: always
reports:
junit: report.xml
\ No newline at end of file
...@@ -29,7 +29,7 @@ The following applies to all files unless otherwise noted: ...@@ -29,7 +29,7 @@ The following applies to all files unless otherwise noted:
-- --
This repository also contains code from Hugging Face Inc., Google Research, This repository also contains code from Hugging Face Inc., Google Research,
and Facebook (from their Fairseq project). Files from these Facebook (from their Fairseq project), and Philip Popien. Files from these
organizations have notices at the top of each file. Below are licenses organizations have notices at the top of each file. Below are licenses
used in those files, as indicated. used in those files, as indicated.
...@@ -262,3 +262,4 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER ...@@ -262,3 +262,4 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
[Megatron](https://arxiv.org/pdf/1909.08053.pdf) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel (tensor and pipeline), and multi-node pre-training of [GPT](https://arxiv.org/abs/2005.14165) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision. Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf) and [2](https://arxiv.org/pdf/2104.04473.pdf)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel (tensor and pipeline), and multi-node pre-training of [GPT](https://arxiv.org/abs/2005.14165) and [BERT](https://arxiv.org/pdf/1810.04805.pdf) using mixed precision.
Below are some of the projects where we have directly used Megatron: Below are some of the projects where we have directly used Megatron:
* [BERT and GPT Studies Using Megatron](https://arxiv.org/pdf/1909.08053.pdf) * [BERT and GPT Studies Using Megatron](https://arxiv.org/pdf/1909.08053.pdf)
...@@ -8,15 +8,15 @@ Below are some of the projects where we have directly used Megatron: ...@@ -8,15 +8,15 @@ Below are some of the projects where we have directly used Megatron:
* [Local Knowledge Powered Conversational Agents](https://arxiv.org/abs/2010.10150) * [Local Knowledge Powered Conversational Agents](https://arxiv.org/abs/2010.10150)
* [MEGATRON-CNTRL: Controllable Story Generation with External Knowledge Using Large-Scale Language Models](https://www.aclweb.org/anthology/2020.emnlp-main.226.pdf) * [MEGATRON-CNTRL: Controllable Story Generation with External Knowledge Using Large-Scale Language Models](https://www.aclweb.org/anthology/2020.emnlp-main.226.pdf)
* [RACE Reading Comprehension Dataset Leaderboard](http://www.qizhexie.com/data/RACE_leaderboard.html) * [RACE Reading Comprehension Dataset Leaderboard](http://www.qizhexie.com/data/RACE_leaderboard.html)
* [Scaling Language Model Training to a Trillion Parameters Using Megatron](https://arxiv.org/pdf/2104.04473.pdf)
* [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf) * [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf)
Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specifc model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. The table below shows the model configurations along with the achieved FLOPs per second (both per GPU and aggregate over all GPUs). Note that the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging. Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specifc model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. The table below shows the model configurations along with the achieved FLOPs (both per GPU and aggregate over all GPUs). Note that the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging.
![Cases](images/cases_jan2021.png) ![Cases](images/cases_april2021.png)
The following figures show achieved percentage of theoretical peak FLOPs and achieved aggregate petaFLOPs per second as a function of number of GPUs. All the cases from 1 billion to 1 trillion achieve more than 41% half precision utilization, which is high for an end-to-end application. We observe that initially as the model parallel size increases, utilization slightly decreases; as hidden size increases for larger models, utilization starts increasing and reaches 49% for the largest model. We also note that achieved aggregate petaFLOPs per second across all GPUs increases almost linearly with number of GPUs, demonstrating good weak scaling. All the cases from 1 billion to 1 trillion parameters achieve more than 43% half precision utilization, which is high for an end-to-end application. We observe that initially the utilization remains constant but as hidden size increases for larger models, utilization starts increasing and reaches 52% for the largest model. We also note that achieved aggregate petaFLOPs across all GPUs increases almost linearly with number of GPUs, demonstrating good weak scaling.
![Model Parallel Scaling](images/scaling.png)
# Contents # Contents
* [Contents](#contents) * [Contents](#contents)
...@@ -370,11 +370,11 @@ python tools/create_doc_index.py \ ...@@ -370,11 +370,11 @@ python tools/create_doc_index.py \
We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning. We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning.
<!-- Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this. Currently only tensor model parallelism is supported on input and pipeline model parallelsim on the output. This example reads in a model with 2-way tensor model parallelism and writes out a model with 2-way pipeline model parallelism.
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.
<pre> <pre>
TENSOR_MODEL_PARALLEL_SIZE=2 TENSOR_MODEL_PARALLEL_SIZE=2
TARGET_PIPELINE_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m CHECKPOINT_PATH=checkpoints/bert_345m
...@@ -382,6 +382,8 @@ CHECKPOINT_PATH=checkpoints/bert_345m ...@@ -382,6 +382,8 @@ CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \ --model-type BERT \
--tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \ --tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
--pipeline-model-parallel-size 1 \
--target-pipeline-model-parallel-size $TARGET_PIPELINE_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \ --tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \ --vocab-file $VOCAB_FILE \
--num-layers 24 \ --num-layers 24 \
...@@ -390,9 +392,10 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \ ...@@ -390,9 +392,10 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--seq-length 512 \ --seq-length 512 \
--max-position-embeddings 512 \ --max-position-embeddings 512 \
--load $CHECKPOINT_PATH --load $CHECKPOINT_PATH
--save $CHECKPOINT_PATH/merged
</pre> </pre>
-->
Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts. Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.
## GPT Text Generation ## GPT Text Generation
......
#!/bin/bash
# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)
RANK=0
WORLD_SIZE=1
# Wikipedia data can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path to store embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
python tools/create_doc_index.py \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--micro-batch-size 128 \
--checkpoint-activations \
--seq-length 512 \
--retriever-seq-length 256 \
--max-position-embeddings 512 \
--load ${CHECKPOINT_PATH} \
--evidence-data-path ${EVIDENCE_DATA_DIR} \
--embedding-path ${EMBEDDING_PATH} \
--indexer-log-interval 1000 \
--indexer-batch-size 128 \
--vocab-file bert-vocab.txt \
--num-workers 2 \
--fp16
#!/bin/bash
# Evaluate natural question test data given Wikipedia embeddings and pretrained
# ICT model
# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path of the embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
QA_FILE=<Path of the natural question test dataset>
python tasks/main.py \
--task ICT-ZEROSHOT-NQ \
--tokenizer-type BertWordPieceLowerCase \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--micro-batch-size 128 \
--checkpoint-activations \
--seq-length 512 \
--max-position-embeddings 512 \
--load ${CHECKPOINT_PATH} \
--evidence-data-path ${EVIDENCE_DATA_DIR} \
--embedding-path ${EMBEDDING_PATH} \
--retriever-seq-length 256 \
--vocab-file bert-vocab.txt\
--qa-data-test ${QA_FILE} \
--num-workers 2 \
--faiss-use-gpu \
--retriever-report-topk-accuracies 1 5 20 100 \
--fp16
...@@ -46,6 +46,7 @@ options=" \ ...@@ -46,6 +46,7 @@ options=" \
--weight-decay 0.1 \ --weight-decay 0.1 \
--adam-beta1 0.9 \ --adam-beta1 0.9 \
--adam-beta2 0.95 \ --adam-beta2 0.95 \
--init-method-std 0.006 \
--tensorboard-dir <TENSORBOARD DIRECTORY> \ --tensorboard-dir <TENSORBOARD DIRECTORY> \
--fp16 \ --fp16 \
--checkpoint-activations " --checkpoint-activations "
......
#! /bin/bash
# Runs the "217M" parameter biencoder model for ICT retriever
RANK=0
WORLD_SIZE=1
PRETRAINED_BERT_PATH=<Specify path of pretrained BERT model>
TEXT_DATA_PATH=<Specify path and file prefix of the text data>
TITLE_DATA_PATH=<Specify path and file prefix od the titles>
CHECKPOINT_PATH=<Specify path>
python pretrain_ict.py \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--micro-batch-size 32 \
--seq-length 256 \
--max-position-embeddings 512 \
--train-iters 100000 \
--vocab-file bert-vocab.txt \
--tokenizer-type BertWordPieceLowerCase \
--DDP-impl torch \
--bert-load ${PRETRAINED_BERT_PATH} \
--log-interval 100 \
--eval-interval 1000 \
--eval-iters 10 \
--retriever-report-topk-accuracies 1 5 10 20 100 \
--retriever-score-scaling \
--load $CHECKPOINT_PATH \
--save $CHECKPOINT_PATH \
--data-path ${TEXT_DATA_PATH} \
--titles-data-path ${TITLE_DATA_PATH} \
--lr 0.0001 \
--lr-decay-style linear \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction 0.01 \
--save-interval 4000 \
--exit-interval 8000 \
--query-in-block-prob 0.1 \
--fp16
default: cases.png scaling-mp.png scaling-dp.png
# for some reason the size option to convert in scaling.tex doesn't work, manually do it after
cases.png scaling-mp.png scaling-dp.png: tables.tex
latex --shell-escape $<
convert tables-1.png -resize 650 cases.png
convert tables-2.png -resize 600 scaling-mp.png
convert tables-3.png -resize 350 scaling-dp.png
clean:
rm -rf *.aux *.log *.dvi *.ps
rm -rf tables-*.png
\documentclass[multi,convert]{standalone}
\usepackage{multirow}
\standaloneenv{tabular}
\begin{document}
\begin{tabular}{cccccc}
Case & Hidden Size & Attention Heads & Layers & Parameters (billions) & Model Parallel Partitions \\
\hline
1B & 1920 & 15 & 24 & 1.16 & 1 \\
2B & 2304 & 18 & 30 & 2.03 & 2 \\
4B & 3072 & 24 & 36 & 4.24 & 4 \\
8B & 4096 & 32 & 42 & 8.67 & 8 \\
\end{tabular}
\begin{tabular}{cc|ccc|ccc}
& & \multicolumn{3}{c|}{\textbf{DGX-2 (V100) batch size 8}} & \multicolumn{3}{c}{\textbf{DGX-A100 batch size 16}} \\
\hline
\multirow{2}{*}{Case} & Number of & Iteration & \multirow{2}{*}{Scaling} & TeraFLOPs & Iteration & \multirow{2}{*}{Scaling} & TeraFLOPs \\
& GPUs & Time (ms) & & per GPU & Time (ms) & & per GPU \\
\hline
1B & 1 & 1121 & 100.0\% & 71.9 & 1076 & 100\% & 149.8 \\
2B & 2 & 1093 & 89.6\% & 64.2 & 1026 & 91.7\% & 136.8 \\
4B & 4 & 1238 & 82.5\% & 58.5 & 1162 & 84.5\% & 124.7 \\
8B & 8 & 1407 & 74.3\% & 52.2 & 1343 & 74.7\% & 109.3 \\
\end{tabular}
\begin{tabular}{cc|ccc}
& & \multicolumn{3}{c}{\textbf{DGX-A100 batch size 2048}} \\
\hline
\multirow{2}{*}{Case} & Number of & Iteration & \multirow{2}{*}{Scaling} & TeraFLOPs \\
& GPUs & Time (ms) & & per GPU \\
\hline
1B & 128 & 1153 & 93.3\% & 139.8 \\
2B & 256 & 1101 & 85.5\% & 127.5 \\
4B & 512 & 1242 & 79.0\% & 116.7 \\
8B & 1024 & 1380 & 72.7\% & 106.5 \\
\end{tabular}
\end{document}
This diff is collapsed.
...@@ -21,18 +21,20 @@ import sys ...@@ -21,18 +21,20 @@ import sys
import numpy as np import numpy as np
import torch import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from megatron import mpu, get_args, update_num_microbatches from megatron import (get_args,
from megatron import get_args mpu,
from megatron import print_rank_0 print_rank_0,
update_num_microbatches,
utils)
_CHECKPOINT_VERSION = None _CHECKPOINT_VERSION = None
def set_checkpoint_version(value): def set_checkpoint_version(value):
global _CHECKPOINT_VERSION global _CHECKPOINT_VERSION
assert _CHECKPOINT_VERSION is None, \ if _CHECKPOINT_VERSION is not None:
"checkpoint version already set" assert _CHECKPOINT_VERSION == value, \
"checkpoint versions do not match"
_CHECKPOINT_VERSION = value _CHECKPOINT_VERSION = value
def get_checkpoint_version(): def get_checkpoint_version():
...@@ -59,9 +61,10 @@ def check_checkpoint_args(checkpoint_args): ...@@ -59,9 +61,10 @@ def check_checkpoint_args(checkpoint_args):
_compare('hidden_size') _compare('hidden_size')
_compare('num_attention_heads') _compare('num_attention_heads')
_compare('max_position_embeddings') _compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by') if args.vocab_file:
_compare('padded_vocab_size') _compare('make_vocab_size_divisible_by')
_compare('tokenizer_type') _compare('padded_vocab_size')
_compare('tokenizer_type')
if get_checkpoint_version() < 3.0: if get_checkpoint_version() < 3.0:
_compare('tensor_model_parallel_size', _compare('tensor_model_parallel_size',
old_arg_name='model_parallel_size') old_arg_name='model_parallel_size')
...@@ -108,21 +111,24 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -108,21 +111,24 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
args = get_args() args = get_args()
# Only rank zero of the data parallel writes to the disk. # Only rank zero of the data parallel writes to the disk.
if isinstance(model, torchDDP): model = utils.unwrap_model(model)
model = model.module
if torch.distributed.get_rank() == 0: print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
print('saving checkpoint at iteration {:7d} to {}'.format( iteration, args.save))
iteration, args.save), flush=True)
if mpu.get_data_parallel_rank() == 0: if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model. # Arguments, iteration, and model.
state_dict = {} state_dict = {}
state_dict['args'] = args state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0 state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration state_dict['iteration'] = iteration
state_dict['model'] = model.state_dict_for_save_checkpoint() if len(model) == 1:
state_dict['model'] = model[0].state_dict_for_save_checkpoint()
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
# Optimizer stuff. # Optimizer stuff.
if not args.no_save_optim: if not args.no_save_optim:
...@@ -146,26 +152,98 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -146,26 +152,98 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
torch.save(state_dict, checkpoint_name) torch.save(state_dict, checkpoint_name)
# Wait so everyone is done (necessary) # Wait so everyone is done (necessary)
torch.distributed.barrier() if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0: torch.distributed.barrier()
print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
iteration, args.save), flush=True) print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
# And update the latest iteration # And update the latest iteration
if torch.distributed.get_rank() == 0: if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save) tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f: with open(tracker_filename, 'w') as f:
f.write(str(iteration)) f.write(str(iteration))
# Wait so everyone is done (not necessary)
torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): # Wait so everyone is done (not necessary)
"""Load a model checkpoint and return the iteration.""" if torch.distributed.is_initialized():
torch.distributed.barrier()
def _transpose_first_dim(t, num_splits, num_splits_first, model):
input_shape = t.size()
# We use a self_attention module but the values extracted aren't
# specific to self attention so should work for cross attention as well
while hasattr(model, 'module'):
model = model.module
attention_module = model.language_model.encoder.layers[0].self_attention
hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
if num_splits_first:
"""[num_splits * np * hn, h]
-->(view) [num_splits, np, hn, h]
-->(tranpose) [np, num_splits, hn, h]
-->(view) [np * num_splits * hn, h] """
intermediate_shape = \
(num_splits, num_attention_heads_per_partition,
hidden_size_per_attention_head) + input_shape[1:]
t = t.view(*intermediate_shape)
t = t.transpose(0, 1).contiguous()
else:
"""[np * hn * num_splits, h]
-->(view) [np, hn, num_splits, h]
-->(tranpose) [np, num_splits, hn, h]
-->(view) [np * num_splits * hn, h] """
intermediate_shape = \
(num_attention_heads_per_partition,
hidden_size_per_attention_head, num_splits) +\
input_shape[1:]
t = t.view(*intermediate_shape)
t = t.transpose(1, 2).contiguous()
t = t.view(*input_shape)
return t
def fix_query_key_value_ordering(model, checkpoint_version):
"""Fix up query/key/value matrix ordering if checkpoint
version is smaller than 2.0
"""
if checkpoint_version < 2.0:
for name, param in model.named_parameters():
if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 3, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 3, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
if name.endswith(('.key_value.weight', '.key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 2, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 2, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version))
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
parameters and buffers in model.
"""
args = get_args() args = get_args()
load_dir = getattr(args, load_arg) load_dir = getattr(args, load_arg)
if isinstance(model, torchDDP): model = utils.unwrap_model(model)
model = model.module
# Read the tracker file and set the iteration. # Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(load_dir) tracker_filename = get_checkpoint_tracker_filename(load_dir)
...@@ -197,9 +275,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): ...@@ -197,9 +275,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
# Checkpoint. # Checkpoint.
checkpoint_name = get_checkpoint_name(load_dir, iteration, release) checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if torch.distributed.get_rank() == 0: print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
print(' loading checkpoint from {} at iteration {}'.format(
args.load, iteration), flush=True)
# Load the checkpoint. # Load the checkpoint.
try: try:
...@@ -252,7 +328,17 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): ...@@ -252,7 +328,17 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
print_rank_0('could not find arguments in the checkpoint ...') print_rank_0('could not find arguments in the checkpoint ...')
# Model. # Model.
model.load_state_dict(state_dict['model']) if len(model) == 1:
model[0].load_state_dict(state_dict['model'], strict=strict)
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering if needed
checkpoint_version = get_checkpoint_version()
print_rank_0(f' checkpoint version {checkpoint_version}')
fix_query_key_value_ordering(model, checkpoint_version)
# Optimizer. # Optimizer.
if not release and not args.finetune and not args.no_load_optim: if not release and not args.finetune and not args.no_load_optim:
...@@ -275,58 +361,64 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): ...@@ -275,58 +361,64 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
np.random.set_state(state_dict['np_rng_state']) np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state']) torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state']) torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
# Check for empty states array
if not state_dict['rng_tracker_states']:
raise KeyError
mpu.get_cuda_rng_tracker().set_states( mpu.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states']) state_dict['rng_tracker_states'])
except KeyError: except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. ' print_rank_0('Unable to load rng state from checkpoint {}. '
'Specify --no-load-rng or --finetune to prevent ' 'Specify --no-load-rng or --finetune to prevent '
'attempting to load the optimizer state, ' 'attempting to load the rng state, '
'exiting ...'.format(checkpoint_name)) 'exiting ...'.format(checkpoint_name))
sys.exit() sys.exit()
torch.distributed.barrier() # Some utilities want to load a checkpoint without distributed being initialized
if torch.distributed.get_rank() == 0: if torch.distributed.is_initialized():
print(' successfully loaded checkpoint from {} at iteration {}'.format( torch.distributed.barrier()
args.load, iteration), flush=True)
print_rank_0(f' successfully loaded checkpoint from {args.load} '
f'at iteration {iteration}')
return iteration return iteration
def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False): def load_biencoder_checkpoint(model, only_query_model=False,
"""selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints""" only_context_model=False, custom_load_path=None):
"""
selectively load retrieval models for indexing/retrieving
from saved checkpoints
"""
args = get_args() args = get_args()
if isinstance(model, torchDDP): model = utils.unwrap_model(model)
model = model.module
load_path = args.load if from_realm_chkpt else args.ict_load load_path = custom_load_path if custom_load_path is not None else args.load
tracker_filename = get_checkpoint_tracker_filename(load_path) tracker_filename = get_checkpoint_tracker_filename(load_path)
with open(tracker_filename, 'r') as f: with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip()) iteration = int(f.read().strip())
# assert iteration > 0
checkpoint_name = get_checkpoint_name(load_path, iteration, False) checkpoint_name = get_checkpoint_name(load_path, iteration, False)
if mpu.get_data_parallel_rank() == 0: if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format( print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name)) torch.distributed.get_rank(), checkpoint_name))
state_dict = torch.load(checkpoint_name, map_location='cpu') state_dict = torch.load(checkpoint_name, map_location='cpu')
ict_state_dict = state_dict['model'] ret_state_dict = state_dict['model']
if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
print(" loading ICT state dict from REALM", flush=True)
ict_state_dict = ict_state_dict['retriever']['ict_model']
if only_query_model: if only_query_model:
ict_state_dict.pop('context_model') ret_state_dict.pop('context_model')
if only_block_model: if only_context_model:
ict_state_dict.pop('question_model') ret_state_dict.pop('query_model')
model.load_state_dict(ict_state_dict) assert len(model) == 1
model[0].load_state_dict(ret_state_dict)
torch.distributed.barrier() torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0: if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name)) print(' successfully loaded {}'.format(checkpoint_name))
return model return model
"""AutoAugment data augmentation policy for ImageNet.
-- Begin license text.
MIT License
Copyright (c) 2018 Philip Popien
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-- End license text.
Code adapted from https://github.com/DeepVoltaire/AutoAugment.
This module implements the fixed AutoAugment data augmentation policy for ImageNet provided in
Appendix A, Table 9 of reference [1]. It does not include any of the search code for augmentation
policies.
Reference:
[1] https://arxiv.org/abs/1805.09501
"""
import random
import numpy as np
from PIL import Image
from PIL import ImageEnhance
from PIL import ImageOps
_MAX_LEVEL = 10 # Maximum integer strength of an augmentation, if applicable.
class ImageNetPolicy:
"""Definition of an ImageNetPolicy.
Implements a fixed AutoAugment data augmentation policy targeted at
ImageNet training by randomly applying at runtime one of the 25 pre-defined
data augmentation sub-policies provided in Reference [1].
Usage example as a Pytorch Transform:
>>> transform=transforms.Compose([transforms.Resize(256),
>>> ImageNetPolicy(),
>>> transforms.ToTensor()])
"""
def __init__(self, fillcolor=(128, 128, 128)):
"""Initialize an ImageNetPolicy.
Args:
fillcolor (tuple): RGB color components of the color to be used for
filling when needed (default: (128, 128, 128), which
corresponds to gray).
"""
# Instantiate a list of sub-policies.
# Each entry of the list is a SubPolicy which consists of
# two augmentation operations,
# each of those parametrized as operation, probability, magnitude.
# Those two operations are applied sequentially on the image upon call.
self.policies = [
SubPolicy("posterize", 0.4, 8, "rotate", 0.6, 9, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
SubPolicy("posterize", 0.6, 7, "posterize", 0.6, 6, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("equalize", 0.4, 4, "rotate", 0.8, 8, fillcolor),
SubPolicy("solarize", 0.6, 3, "equalize", 0.6, 7, fillcolor),
SubPolicy("posterize", 0.8, 5, "equalize", 1.0, 2, fillcolor),
SubPolicy("rotate", 0.2, 3, "solarize", 0.6, 8, fillcolor),
SubPolicy("equalize", 0.6, 8, "posterize", 0.4, 6, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 0.4, 0, fillcolor),
SubPolicy("rotate", 0.4, 9, "equalize", 0.6, 2, fillcolor),
SubPolicy("equalize", 0.0, 7, "equalize", 0.8, 8, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("rotate", 0.8, 8, "color", 1.0, 2, fillcolor),
SubPolicy("color", 0.8, 8, "solarize", 0.8, 7, fillcolor),
SubPolicy("sharpness", 0.4, 7, "invert", 0.6, 8, fillcolor),
SubPolicy("shearX", 0.6, 5, "equalize", 1.0, 9, fillcolor),
SubPolicy("color", 0.4, 0, "equalize", 0.6, 3, fillcolor),
SubPolicy("equalize", 0.4, 7, "solarize", 0.2, 4, fillcolor),
SubPolicy("solarize", 0.6, 5, "autocontrast", 0.6, 5, fillcolor),
SubPolicy("invert", 0.6, 4, "equalize", 1.0, 8, fillcolor),
SubPolicy("color", 0.6, 4, "contrast", 1.0, 8, fillcolor),
SubPolicy("equalize", 0.8, 8, "equalize", 0.6, 3, fillcolor),
]
def __call__(self, img):
"""Define call method for ImageNetPolicy class."""
policy_idx = random.randint(0, len(self.policies) - 1)
return self.policies[policy_idx](img)
def __repr__(self):
"""Define repr method for ImageNetPolicy class."""
return "ImageNetPolicy"
class SubPolicy:
"""Definition of a SubPolicy.
A SubPolicy consists of two augmentation operations,
each of those parametrized as operation, probability, magnitude.
The two operations are applied sequentially on the image upon call.
"""
def __init__(
self,
operation1,
probability1,
magnitude_idx1,
operation2,
probability2,
magnitude_idx2,
fillcolor,
):
"""Initialize a SubPolicy.
Args:
operation1 (str): Key specifying the first augmentation operation.
There are fourteen key values altogether (see supported_ops below
listing supported operations). probability1 (float): Probability
within [0., 1.] of applying the first augmentation operation.
magnitude_idx1 (int): Integer specifiying the strength of the first
operation as an index further used to derive the magnitude from a
range of possible values.
operation2 (str): Key specifying the second augmentation operation.
probability2 (float): Probability within [0., 1.] of applying the
second augmentation operation.
magnitude_idx2 (int): Integer specifiying the strength of the
second operation as an index further used to derive the magnitude
from a range of possible values.
fillcolor (tuple): RGB color components of the color to be used for
filling.
Returns:
"""
# List of supported operations for operation1 and operation2.
supported_ops = [
"shearX",
"shearY",
"translateX",
"translateY",
"rotate",
"color",
"posterize",
"solarize",
"contrast",
"sharpness",
"brightness",
"autocontrast",
"equalize",
"invert",
]
assert (operation1 in supported_ops) and (
operation2 in supported_ops
), "SubPolicy:one of oper1 or oper2 refers to an unsupported operation."
assert (
0.0 <= probability1 <= 1.0 and 0.0 <= probability2 <= 1.0
), "SubPolicy: prob1 and prob2 should be within [0., 1.]."
assert (
isinstance(magnitude_idx1, int) and 0 <= magnitude_idx1 <= 10
), "SubPolicy: idx1 should be specified as an integer within [0, 10]."
assert (
isinstance(magnitude_idx2, int) and 0 <= magnitude_idx2 <= 10
), "SubPolicy: idx2 should be specified as an integer within [0, 10]."
# Define a dictionary where each key refers to a specific type of
# augmentation and the corresponding value is a range of ten possible
# magnitude values for that augmentation.
num_levels = _MAX_LEVEL + 1
ranges = {
"shearX": np.linspace(0, 0.3, num_levels),
"shearY": np.linspace(0, 0.3, num_levels),
"translateX": np.linspace(0, 150 / 331, num_levels),
"translateY": np.linspace(0, 150 / 331, num_levels),
"rotate": np.linspace(0, 30, num_levels),
"color": np.linspace(0.0, 0.9, num_levels),
"posterize": np.round(np.linspace(8, 4, num_levels), 0).astype(
np.int
),
"solarize": np.linspace(256, 0, num_levels), # range [0, 256]
"contrast": np.linspace(0.0, 0.9, num_levels),
"sharpness": np.linspace(0.0, 0.9, num_levels),
"brightness": np.linspace(0.0, 0.9, num_levels),
"autocontrast": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"equalize": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
"invert": [0]
* num_levels, # This augmentation doesn't use magnitude parameter.
}
def rotate_with_fill(img, magnitude):
"""Define rotation transformation with fill.
The input image is first rotated, then it is blended together with
a gray mask of the same size. Note that fillcolor as defined
elsewhere in this module doesn't apply here.
Args:
magnitude (float): rotation angle in degrees.
Returns:
rotated_filled (PIL Image): rotated image with gray filling for
disoccluded areas unveiled by the rotation.
"""
rotated = img.convert("RGBA").rotate(magnitude)
rotated_filled = Image.composite(
rotated, Image.new("RGBA", rotated.size, (128,) * 4), rotated
)
return rotated_filled.convert(img.mode)
# Define a dictionary of augmentation functions where each key refers
# to a specific type of augmentation and the corresponding value defines
# the augmentation itself using a lambda function.
# pylint: disable=unnecessary-lambda
func_dict = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
Image.BICUBIC,
fillcolor=fillcolor,
),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
magnitude * img.size[0] * random.choice([-1, 1]),
0,
1,
0,
),
fillcolor=fillcolor,
),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(
1,
0,
0,
0,
1,
magnitude * img.size[1] * random.choice([-1, 1]),
),
fillcolor=fillcolor,
),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
1 + magnitude * random.choice([-1, 1])
),
"posterize": lambda img, magnitude: ImageOps.posterize(
img, magnitude
),
"solarize": lambda img, magnitude: ImageOps.solarize(
img, magnitude
),
"contrast": lambda img, magnitude: ImageEnhance.Contrast(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"sharpness": lambda img, magnitude: ImageEnhance.Sharpness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"brightness": lambda img, magnitude: ImageEnhance.Brightness(
img
).enhance(1 + magnitude * random.choice([-1, 1])),
"autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img),
}
# Store probability, function and magnitude of the first augmentation
# for the sub-policy.
self.probability1 = probability1
self.operation1 = func_dict[operation1]
self.magnitude1 = ranges[operation1][magnitude_idx1]
# Store probability, function and magnitude of the second augmentation
# for the sub-policy.
self.probability2 = probability2
self.operation2 = func_dict[operation2]
self.magnitude2 = ranges[operation2][magnitude_idx2]
def __call__(self, img):
"""Define call method for SubPolicy class."""
# Randomly apply operation 1.
if random.random() < self.probability1:
img = self.operation1(img, self.magnitude1)
# Randomly apply operation 2.
if random.random() < self.probability2:
img = self.operation2(img, self.magnitude2)
return img
...@@ -36,13 +36,14 @@ class BertDataset(Dataset): ...@@ -36,13 +36,14 @@ class BertDataset(Dataset):
def __init__(self, name, indexed_dataset, data_prefix, def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob, num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed): max_seq_length, short_seq_prob, seed, binary_head):
# Params to store. # Params to store.
self.name = name self.name = name
self.seed = seed self.seed = seed
self.masked_lm_prob = masked_lm_prob self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length self.max_seq_length = max_seq_length
self.binary_head = binary_head
# Dataset. # Dataset.
self.indexed_dataset = indexed_dataset self.indexed_dataset = indexed_dataset
...@@ -55,7 +56,8 @@ class BertDataset(Dataset): ...@@ -55,7 +56,8 @@ class BertDataset(Dataset):
self.max_seq_length, self.max_seq_length,
short_seq_prob, short_seq_prob,
self.seed, self.seed,
self.name) self.name,
self.binary_head)
# Vocab stuff. # Vocab stuff.
tokenizer = get_tokenizer() tokenizer = get_tokenizer()
...@@ -81,7 +83,8 @@ class BertDataset(Dataset): ...@@ -81,7 +83,8 @@ class BertDataset(Dataset):
self.vocab_id_to_token_dict, self.vocab_id_to_token_dict,
self.cls_id, self.sep_id, self.cls_id, self.sep_id,
self.mask_id, self.pad_id, self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng) self.masked_lm_prob, np_rng,
self.binary_head)
def get_samples_mapping_(indexed_dataset, def get_samples_mapping_(indexed_dataset,
...@@ -91,7 +94,8 @@ def get_samples_mapping_(indexed_dataset, ...@@ -91,7 +94,8 @@ def get_samples_mapping_(indexed_dataset,
max_seq_length, max_seq_length,
short_seq_prob, short_seq_prob,
seed, seed,
name): name,
binary_head):
if not num_epochs: if not num_epochs:
if not max_num_samples: if not max_num_samples:
raise ValueError("Need to specify either max_num_samples " raise ValueError("Need to specify either max_num_samples "
...@@ -137,7 +141,8 @@ def get_samples_mapping_(indexed_dataset, ...@@ -137,7 +141,8 @@ def get_samples_mapping_(indexed_dataset,
max_seq_length - 3, # account for added tokens max_seq_length - 3, # account for added tokens
short_seq_prob, short_seq_prob,
seed, seed,
verbose) verbose,
2 if binary_head else 1)
print_rank_0(' > done building sapmles index maping') print_rank_0(' > done building sapmles index maping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True) np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format( print_rank_0(' > saved the index mapping in {}'.format(
...@@ -173,7 +178,7 @@ def build_training_sample(sample, ...@@ -173,7 +178,7 @@ def build_training_sample(sample,
target_seq_length, max_seq_length, target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict, vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id, cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng): masked_lm_prob, np_rng, binary_head):
"""Biuld training sample. """Biuld training sample.
Arguments: Arguments:
...@@ -193,12 +198,21 @@ def build_training_sample(sample, ...@@ -193,12 +198,21 @@ def build_training_sample(sample,
the opper bound whereas the numpy one is exclusive. the opper bound whereas the numpy one is exclusive.
""" """
# We assume that we have at least two sentences in the sample if binary_head:
assert len(sample) > 1 # We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B). # Divide sample into two segments (A and B).
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng) if binary_head:
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
np_rng)
else:
tokens_a = []
for j in range(len(sample)):
tokens_a.extend(sample[j])
tokens_b = []
is_next_random = False
# Truncate to `target_sequence_length`. # Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length max_num_tokens = target_seq_length
......
import os
import time
import numpy as np
import torch
from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, \
pad_and_convert_to_numpy
from megatron.data.data_samplers import MegatronPretrainingSampler
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
"""Specifically one epoch to be used in an indexing job."""
args = get_args()
if micro_batch_size is None:
micro_batch_size = args.micro_batch_size
num_workers = args.num_workers
# Use megatron's sampler with consumed samples set to 0 as
# this is only for evaluation and don't intend to resume half way.
# Also, set the drop last to false as don't intend to remove
# the last batch
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=0,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size(),
drop_last=False)
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
def get_ict_batch(data_iterator):
# Items and their type.
keys = ['query_tokens', 'query_mask',
'context_tokens', 'context_mask', 'block_data']
datatype = torch.int64
# Broadcast data.
if data_iterator is None:
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
query_mask = data_b['query_mask'] < 0.5
context_tokens = data_b['context_tokens'].long()
context_mask = data_b['context_mask'] < 0.5
block_indices = data_b['block_data'].long()
return query_tokens, query_mask,\
context_tokens, context_mask, block_indices
def join_str_list(str_list):
"""Join a list of strings, handling spaces appropriately"""
result = ""
for s in str_list:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
class BlockSampleData(object):
"""A struct for fully describing a fixed-size block of data as used in REALM
:param start_idx: for first sentence of the block
:param end_idx: for last sentence of the block (may be partially truncated in sample construction)
:param doc_idx: the index of the document from which the block comes in the original indexed dataset
:param block_idx: a unique integer identifier given to every block.
"""
def __init__(self, start_idx, end_idx, doc_idx, block_idx):
self.start_idx = start_idx
self.end_idx = end_idx
self.doc_idx = doc_idx
self.block_idx = block_idx
def as_array(self):
return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
def as_tuple(self):
return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
class BlockSamplesMapping(object):
def __init__(self, mapping_array):
# make sure that the array is compatible with BlockSampleData
assert mapping_array.shape[1] == 4
self.mapping_array = mapping_array
def __len__(self):
return self.mapping_array.shape[0]
def __getitem__(self, idx):
"""Get the data associated with an indexed sample."""
sample_data = BlockSampleData(*self.mapping_array[idx])
return sample_data
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
"""Get samples mapping for a dataset over fixed size blocks. This function also requires
a dataset of the titles for the source documents since their lengths must be taken into account.
:return: samples_mapping (BlockSamplesMapping)
"""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{}s'.format(seed)
if use_one_sent_docs:
indexmap_filename += '_1sentok'
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if mpu.get_data_parallel_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert block_dataset.doc_idx.dtype == np.int64
assert block_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
from megatron.data import helpers
mapping_array = helpers.build_blocks_mapping(
block_dataset.doc_idx,
block_dataset.sizes,
title_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length - 3, # account for added tokens
seed,
verbose,
use_one_sent_docs)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, mapping_array, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
samples_mapping = BlockSamplesMapping(mapping_array)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
mapping_array.shape[0]))
return samples_mapping
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import torch import torch
import random
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
...@@ -30,12 +30,23 @@ def build_pretraining_data_loader(dataset, consumed_samples): ...@@ -30,12 +30,23 @@ def build_pretraining_data_loader(dataset, consumed_samples):
args = get_args() args = get_args()
# Megatron sampler # Megatron sampler
batch_sampler = MegatronPretrainingSampler( if args.dataloader_type == 'single':
total_samples=len(dataset), batch_sampler = MegatronPretrainingSampler(
consumed_samples=consumed_samples, total_samples=len(dataset),
micro_batch_size=args.micro_batch_size, consumed_samples=consumed_samples,
data_parallel_rank=mpu.get_data_parallel_rank(), micro_batch_size=args.micro_batch_size,
data_parallel_size=mpu.get_data_parallel_world_size()) data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
else:
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))
# Torch dataloader. # Torch dataloader.
return torch.utils.data.DataLoader(dataset, return torch.utils.data.DataLoader(dataset,
...@@ -43,19 +54,18 @@ def build_pretraining_data_loader(dataset, consumed_samples): ...@@ -43,19 +54,18 @@ def build_pretraining_data_loader(dataset, consumed_samples):
num_workers=args.num_workers, num_workers=args.num_workers,
pin_memory=True) pin_memory=True)
class MegatronPretrainingSampler: class MegatronPretrainingSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size, def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size): data_parallel_rank, data_parallel_size, drop_last=True):
# Keep a copy of input params for later use. # Keep a copy of input params for later use.
self.total_samples = total_samples self.total_samples = total_samples
self.consumed_samples = consumed_samples self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = self.micro_batch_size * \ self.micro_batch_times_data_parallel_size = \
data_parallel_size self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
# Sanity checks. # Sanity checks.
assert self.total_samples > 0, \ assert self.total_samples > 0, \
...@@ -69,18 +79,79 @@ class MegatronPretrainingSampler: ...@@ -69,18 +79,79 @@ class MegatronPretrainingSampler:
'data_parallel_rank should be smaller than data size: {}, ' \ 'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size) '{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self): def __len__(self):
return self.total_samples return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx
def __iter__(self): def __iter__(self):
batch = [] batch = []
# Last batch if not complete will be dropped. # Last batch will be dropped if drop_last is not set False
for idx in range(self.consumed_samples, self.total_samples): for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx) batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size: if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx = self.data_parallel_rank * self.micro_batch_size start_idx, end_idx = self.get_start_end_idx()
end_idx = start_idx + self.micro_batch_size
yield batch[start_idx:end_idx] yield batch[start_idx:end_idx]
batch = [] batch = []
# Check the last partial batch and see drop_last is set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []
...@@ -114,7 +114,6 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): ...@@ -114,7 +114,6 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
"""Truncates a pair of sequences to a maximum sequence length.""" """Truncates a pair of sequences to a maximum sequence length."""
#print(len_a, len_b, max_num_tokens) #print(len_a, len_b, max_num_tokens)
assert len_a > 0 assert len_a > 0
assert len_b > 0
if len_a + len_b <= max_num_tokens: if len_a + len_b <= max_num_tokens:
return False return False
while len_a + len_b > max_num_tokens: while len_a + len_b > max_num_tokens:
...@@ -150,10 +149,11 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): ...@@ -150,10 +149,11 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
for token in tokens_b: for token in tokens_b:
tokens.append(token) tokens.append(token)
tokentypes.append(1) tokentypes.append(1)
# [SEP]. if tokens_b:
tokens.append(sep_id) # [SEP].
tokentypes.append(1) tokens.append(sep_id)
tokentypes.append(1)
return tokens, tokentypes return tokens, tokentypes
...@@ -392,6 +392,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -392,6 +392,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples, train_valid_test_num_samples,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup, short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'): dataset_type='standard_bert'):
if len(data_prefix) == 1: if len(data_prefix) == 1:
...@@ -401,6 +402,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -401,6 +402,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, short_seq_prob, seed,
skip_warmup, skip_warmup,
binary_head,
dataset_type=dataset_type) dataset_type=dataset_type)
# Blending dataset. # Blending dataset.
# Parse the values. # Parse the values.
...@@ -417,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -417,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
prefixes[i], data_impl, splits_string, prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i], datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob, max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, dataset_type=dataset_type) seed, skip_warmup, binary_head, dataset_type=dataset_type)
if train_ds: if train_ds:
train_datasets.append(train_ds) train_datasets.append(train_ds)
if valid_ds: if valid_ds:
...@@ -444,6 +446,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -444,6 +446,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples, train_valid_test_num_samples,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup, short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'): dataset_type='standard_bert'):
if dataset_type not in DSET_TYPES: if dataset_type not in DSET_TYPES:
...@@ -503,7 +506,8 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -503,7 +506,8 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
num_epochs=None, num_epochs=None,
max_num_samples=train_valid_test_num_samples[index], max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length, max_seq_length=max_seq_length,
seed=seed seed=seed,
binary_head=binary_head
) )
if dataset_type == DSET_TYPE_ICT: if dataset_type == DSET_TYPE_ICT:
......
...@@ -189,6 +189,9 @@ inline int32_t get_target_sample_len(const int32_t short_seq_ratio, ...@@ -189,6 +189,9 @@ inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length, const int32_t max_length,
std::mt19937& rand32_gen) { std::mt19937& rand32_gen) {
/* Training sample length. */ /* Training sample length. */
if (short_seq_ratio == 0) {
return max_length;
}
const auto random_number = rand32_gen(); const auto random_number = rand32_gen();
if ((random_number % short_seq_ratio) == 0) { if ((random_number % short_seq_ratio) == 0) {
return 2 + random_number % (max_length - 1); return 2 + random_number % (max_length - 1);
...@@ -205,7 +208,8 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -205,7 +208,8 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
const int32_t max_seq_length, const int32_t max_seq_length,
const double short_seq_prob, const double short_seq_prob,
const int32_t seed, const int32_t seed,
const bool verbose) { const bool verbose,
const int32_t min_num_sent) {
/* Build a mapping of (start-index, end-index, sequence-length) where /* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length. and sequence-length is the target sequence length.
...@@ -214,7 +218,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -214,7 +218,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
// Consistency checks. // Consistency checks.
assert(num_epochs > 0); assert(num_epochs > 0);
assert(max_seq_length > 1); assert(max_seq_length > 1);
assert(short_seq_prob > 0.0); assert(short_seq_prob >= 0.0);
assert(short_seq_prob <= 1.0); assert(short_seq_prob <= 1.0);
assert(seed > 0); assert(seed > 0);
...@@ -223,7 +227,10 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -223,7 +227,10 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
auto sizes = sizes_.unchecked<1>(); auto sizes = sizes_.unchecked<1>();
// For efficiency, convert probability to ratio. Note: rand() generates int. // For efficiency, convert probability to ratio. Note: rand() generates int.
const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob)); int32_t short_seq_ratio = 0;
if (short_seq_prob > 0) {
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
}
if (verbose) { if (verbose) {
const auto sent_start_index = docs[0]; const auto sent_start_index = docs[0];
...@@ -322,7 +329,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -322,7 +329,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
} }
// If we have more than two sentences. // If we have more than two sentences.
if ((num_remain_sent > 1) && (!contains_long_sentence)) { if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values. // Set values.
auto seq_len = int32_t{0}; auto seq_len = int32_t{0};
...@@ -346,7 +353,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -346,7 +353,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
// and if we have reached end of the document. // and if we have reached end of the document.
if (((seq_len >= target_seq_len) && if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) && (num_remain_sent > 1) &&
(num_sent > 1) ) || (num_remain_sent == 0)) { (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Check for overflow. // Check for overflow.
if ((3 * map_index + 2) > if ((3 * map_index + 2) >
...@@ -437,7 +444,8 @@ py::array build_mapping(const py::array_t<int64_t>& docs_, ...@@ -437,7 +444,8 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
const int max_seq_length, const int max_seq_length,
const double short_seq_prob, const double short_seq_prob,
const int seed, const int seed,
const bool verbose) { const bool verbose,
const int32_t min_num_sent) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) { if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) { if (verbose) {
...@@ -445,14 +453,16 @@ py::array build_mapping(const py::array_t<int64_t>& docs_, ...@@ -445,14 +453,16 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
} }
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs, return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length, max_num_samples, max_seq_length,
short_seq_prob, seed, verbose); short_seq_prob, seed, verbose,
min_num_sent);
} else { } else {
if (verbose) { if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush; cout << " using uint32 for data mapping..." << endl << std::flush;
} }
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs, return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length, max_num_samples, max_seq_length,
short_seq_prob, seed, verbose); short_seq_prob, seed, verbose,
min_num_sent);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment