Commit 762072e0 authored by Jared Casper

Merge branch 'lmcafee/retro' into 'main'

Retro

See merge request ADLR/megatron-lm!489
parents ef59b687 17a6044d
@@ -3,3 +3,4 @@ __pycache__
 build
 .coverage_*
 *.egg-info
+*~
@@ -42,7 +42,9 @@ The following table shows both model (MFU) and hardware (HFU) FLOPs utilization
 * [Distributed Pretraining](#distributed-pretraining)
 * [Activation Checkpointing and Recomputation](#activation-checkpointing-and-recomputation)
 * [Distributed Optimizer](#distributed-optimizer)
+* [FlashAttention](#flashattention)
 * [GPT-3 Example](#gpt-3-example)
+* [Retro](#retro)
 * [Evaluation and Tasks](#evaluation-and-tasks)
 * [GPT Text Generation](#gpt-text-generation)
 * [GPT Evaluation](#gpt-evaluation)
@@ -323,6 +325,19 @@ In `examples/pretrain_gpt3_175B.sh` we have provided an example of how to config
 With full global batch size of 1536 on 1024 A100 GPUs, each iteration takes around 32 seconds resulting in 138 teraFLOPs per GPU which is 44% of the theoretical peak FLOPs.

+## Retro
+
+See:
+
+- `tools/retro/README.md` for an overview.
+- `tools/retro/examples/get_preprocess_cmd.sh` for an example of common preprocessing arguments.
+- `tools/retro/examples/preprocess_data.sh` for an example of how to preprocess data.
+- `tools/retro/examples/pretrain_model.sh` for an example of how to pretrain a model.
+
+Retro is a retrieval-enhanced model based on GPT. As described in [Improving language models by retrieving from trillions of tokens](https://arxiv.org/abs/2112.04426), Retro retrieves from a database of document chunks by performing locality search using a sample's tokens. The retrieval database can be large -- often billions or even trillions of tokens -- and provides a more efficient storage mechanism for factual knowledge than storing it implicitly within the network's parameters.
+
+Using Retro requires two steps: 1) preprocessing the retrieval database and pretraining neighbors, and 2) pretraining a model using this data. Please see `tools/retro/README.md` for a detailed overview.
+
 <!--
 ## REALM Pipeline
 We are working on implementing the [REALM](https://arxiv.org/pdf/2002.08909.pdf) system. The following sections (will) reflect the three stages of training it. For now it's just the ICT code.
...
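The 44% utilization quoted above is consistent with the A100's dense FP16/BF16 peak of 312 teraFLOPs; a quick sanity check (the peak value is an assumption, not stated in the README):

```python
# Sanity check of the MFU figure above, assuming the A100's dense
# FP16/BF16 peak of 312 teraFLOPs per GPU.
achieved_tflops = 138
peak_tflops = 312
print(f"MFU: {achieved_tflops / peak_tflops:.0%}")  # -> MFU: 44%
```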
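Step 2 picks up the preprocessing outputs through `--retro-workdir`. A minimal sketch of how pretraining consumes that directory, mirroring the argument-loading logic added to `megatron/arguments.py` later in this diff (the workdir path and args filename here are hypothetical; the real path is resolved by `tools/retro/utils.get_args_path`):

```python
import json
import types

# Hypothetical location; the retro workdir is created during preprocessing.
retro_args_path = "/path/to/retro/workdir/args.json"

with open(retro_args_path) as f:
    retro_args = types.SimpleNamespace(**json.load(f))

# Retrieved sequence length = chunks retrieved per neighbor
# (--retro-num-retrieved-chunks, default 2) times the GPT chunk length
# recorded at preprocessing time.
retro_args.retro_gpt_retrieved_length = 2 * retro_args.retro_gpt_chunk_length
```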
@@ -11,10 +11,10 @@ import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir, os.path.pardir)))
 from megatron import get_args
-from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
-from megatron import mpu
+from megatron import print_rank_0
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, ModelType
...
@@ -10,10 +10,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir, os.path.pardir)))
 import torch
 from megatron import get_args
-from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
+from megatron import print_rank_0
 from megatron.checkpointing import load_checkpoint
+from megatron.core import mpu
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel
 from megatron.training import get_model
...
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import torch
-from .global_vars import get_args
+from .global_vars import get_args, get_retro_args
 from .global_vars import get_current_global_batch_size
 from .global_vars import get_num_microbatches
 from .global_vars import get_signal_handler
...
@@ -3,9 +3,14 @@
 """Megatron arguments."""

 import argparse
+import json
 import os
 import torch
+import types

+from megatron.global_vars import set_retro_args, get_retro_args
+from tools.retro.utils import get_args_path as get_retro_args_path

 def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     """Parse all arguments."""
@@ -29,6 +34,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     parser = _add_logging_args(parser)
     parser = _add_inference_args(parser)
     parser = _add_transformer_engine_args(parser)
+    parser = _add_retro_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:
@@ -333,7 +339,6 @@ def validate_args(args, defaults={}):
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False
-
     if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
         if args.sequence_parallel:
             raise RuntimeError(
@@ -344,15 +349,31 @@ def validate_args(args, defaults={}):
                 "Using async gradient all reduce requires setting the environment "
                 "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")

+    # Load retro args.
+    if args.retro_workdir:
+        retro_args_path = get_retro_args_path(args.retro_workdir)
+        if os.path.exists(retro_args_path):
+            with open(retro_args_path) as f:
+                retro_args = types.SimpleNamespace(**json.load(f))
+                retro_args.retro_return_doc_ids = args.retro_return_doc_ids
+                retro_args.retro_gpt_retrieved_length = \
+                    args.retro_num_retrieved_chunks * \
+                    retro_args.retro_gpt_chunk_length
+                set_retro_args(retro_args)
+
-    _print_args(args)
+    # Print arguments.
+    _print_args("arguments", args)
+    retro_args = get_retro_args()
+    if retro_args and args != retro_args:
+        _print_args("retro arguments", types.SimpleNamespace(**{k:v for k,v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank))
+
     return args
-def _print_args(args):
+def _print_args(title, args):
     """Print arguments."""
     if args.rank == 0:
-        print('------------------------ arguments ------------------------',
+        print(f'------------------------ {title} ------------------------',
               flush=True)
         str_list = []
         for arg in vars(args):
@@ -360,7 +381,7 @@ def _print_args(args):
             str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
         for arg in sorted(str_list, key=lambda x: x.lower()):
             print(arg, flush=True)
-        print('-------------------- end of arguments ---------------------',
+        print(f'-------------------- end of {title} ---------------------',
               flush=True)
@@ -403,15 +424,67 @@ def _add_inference_args(parser):
                        help='During inference, if batch-size times '
                        'sequence-length is smaller than this threshold '
                        'then we will not use pipelining, otherwise we will.')
+    group.add_argument('--max-tokens-to-oom',
+                       type=int, default=12000,
+                       help='Maximum number of tokens during inference; '
+                       'tokens here is # in prompt + # to generate. '
+                       'Allows us to throw an error before OOM crashes server.')
+    group.add_argument('--output-bert-embeddings', action='store_true',
+                       help='Output Bert embeddings (via mean pooling) from '
+                       'model, rather than its binary head output or entire '
+                       'hidden batch.')
+    group.add_argument('--bert-embedder-type', default="megatron",
+                       choices=["megatron", "huggingface"],
+                       help='Select either Megatron or Huggingface as the '
+                       'Bert embedder.')

     return parser

+def _add_retro_args(parser):
+    group = parser.add_argument_group(title='retro')
+
+    group.add_argument('--retro-workdir', default=None,
+                       help='Retro working directory, which contains the '
+                       'preprocessed data for pretraining. This directory '
+                       'is built during preprocessing (see '
+                       'tools/retro/README.md), and contains subdirectories '
+                       'for the chunk database and pretraining neighbors.')
+    group.add_argument('--retro-add-retriever',
+                       action='store_true', default=False,
+                       help='Add a retriever to the transformer, for use in '
+                       'pretraining a Retro model.')
+    group.add_argument('--retro-cyclic-train-iters', type=int, default=None,
+                       help='Set number of training iterations for cyclic '
+                       'Retro training.')
+    group.add_argument('--retro-encoder-layers', type=int, default=2,
+                       help='Number of layers to use for the retrieval '
+                       'encoder.')
+    group.add_argument('--retro-encoder-hidden-dropout',
+                       type=float, default=0.1, help='Hidden dropout for '
+                       'retrieval encoder.')
+    group.add_argument('--retro-encoder-attention-dropout',
+                       type=float, default=0.1, help='Attention dropout for '
+                       'retrieval encoder.')
+    group.add_argument("--retro-num-neighbors", type=int, default=2,
+                       help='Number of neighbors to retrieve during '
+                       'pretraining.')
+    group.add_argument("--retro-num-retrieved-chunks", type=int, default=2,
+                       help='Number of chunks to retrieve from the retrieval '
+                       'database.')
+    group.add_argument("--retro-return-doc-ids", action="store_true",
+                       help="Turn this on when preprocessing retro data.")
+
+    # Enforce argument naming convention.
+    for action in group._group_actions:
+        prefix = action.dest.split("_")[0]
+        assert prefix == "retro", \
+            "Retro args must be prefixed with '--retro-*', for consistent " \
+            "styling. Please fix '%s'." % ", ".join(action.option_strings)
+
+    return parser
+
+
 def _add_network_size_args(parser):
     group = parser.add_argument_group(title='network size')
@@ -775,6 +848,10 @@ def _add_checkpointing_args(parser):
     group.add_argument('--use-checkpoint-args', action='store_true',
                        help='Override any command line arguments with arguments '
                        'from the checkpoint')
+    group.add_argument('--exit-on-missing-checkpoint', action='store_true',
+                       help="If '--load' is set, but checkpoint is not found "
+                       "(e.g., path typo), then exit instead of random "
+                       "initialization.")

     return parser
@@ -835,6 +912,8 @@ def _add_distributed_args(parser):
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
+    group.add_argument('--distributed-timeout-minutes', type=int, default=10,
+                       help='Timeout minutes for torch.distributed.')
     group.add_argument('--DDP-impl', default='local',
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
...
@@ -508,7 +508,16 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
         use_distributed_optimizer=args.use_distributed_optimizer,
         rank0=False)

+    # Checkpoint not loaded.
     if model_state_dict is None:
+
+        # Conditionally exit at this point.
+        if args.exit_on_missing_checkpoint:
+            print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<")
+            torch.distributed.barrier()
+            sys.exit()
+
+        # Iteration defaults to 0.
         return 0

     # set checkpoint version
...
@@ -19,7 +19,6 @@ from megatron.data.dataset_utils import (
     create_masked_lm_predictions
 )

-
 class BertDataset(torch.utils.data.Dataset):

     def __init__(self, name, indexed_dataset, data_prefix,
@@ -156,7 +155,9 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     # Some checks.
     num_tokens = len(tokens)
     padding_length = max_seq_length - num_tokens
-    assert padding_length >= 0
+    assert padding_length >= 0, \
+        f"num_tokens ({num_tokens}) is greater than " \
+        f"max_seq_length ({max_seq_length})."
     assert len(tokentypes) == num_tokens
     assert len(masked_positions) == len(masked_labels)
...
@@ -50,4 +50,7 @@ class BlendableDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         dataset_idx = self.dataset_index[idx]
         sample_idx = self.dataset_sample_index[idx]
-        return self.datasets[dataset_idx][sample_idx]
+        return {
+            "dataset_idx" : dataset_idx,
+            **self.datasets[dataset_idx][sample_idx],
+        }
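After this change each blended sample also records which constituent dataset produced it. A toy sketch of the new return shape (the stand-in dataset below is illustrative, not Megatron code):

```python
import numpy as np

class ToyDataset:
    """Stand-in for a wrapped dataset returning Megatron-style samples."""
    def __getitem__(self, idx):
        return {"text": np.arange(4, dtype=np.int64)}

datasets = [ToyDataset(), ToyDataset()]
dataset_idx, sample_idx = 1, 0

# Mirrors BlendableDataset.__getitem__ after this change: the blend index
# is merged into the sample dict under "dataset_idx".
item = {"dataset_idx": dataset_idx, **datasets[dataset_idx][sample_idx]}
print(item)  # {'dataset_idx': 1, 'text': array([0, 1, 2, 3])}
```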
@@ -452,7 +452,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             prefixes[i], data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             max_seq_length, masked_lm_prob, short_seq_prob,
-            seed, skip_warmup, binary_head, dataset_type=dataset_type)
+            seed, skip_warmup, binary_head, max_seq_length_dec,
+            dataset_type=dataset_type)
         if train_ds:
             train_datasets.append(train_ds)
         if valid_ds:
@@ -460,7 +461,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         if test_ds:
             test_datasets.append(test_ds)

     # Blend.
     blending_train_dataset = None
     if train_datasets:
         blending_train_dataset = BlendableDataset(train_datasets, weights)
...
@@ -16,15 +16,18 @@ from megatron.data.dataset_utils import get_train_valid_test_split_
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

-def build_train_valid_test_datasets(data_prefix, data_impl,
-                                    splits_string, train_valid_test_num_samples,
+def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                    train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup,
-                                    train_data_prefix=None, valid_data_prefix=None,
-                                    test_data_prefix=None,):
+                                    train_data_prefix=None,
+                                    valid_data_prefix=None,
+                                    test_data_prefix=None,
+                                    return_doc_ids=False):
     """Build train, valid, and test datasets."""

     if data_prefix:
         print_rank_0("Single data path provided for train, valid & test")

         # Single dataset.
         if len(data_prefix) == 1:
             return _build_train_valid_test_datasets(data_prefix[0],
@@ -35,7 +38,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         # Blending dataset.
         # Parse the values.
         output = get_datasets_weights_and_num_samples(data_prefix,
                                                       train_valid_test_num_samples)
         prefixes, weights, datasets_train_valid_test_num_samples = output

         # Build individual datasets.
@@ -46,7 +49,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
                 prefixes[i], data_impl, splits_string,
                 datasets_train_valid_test_num_samples[i],
-                seq_length, seed, skip_warmup)
+                seq_length, seed, skip_warmup,
+                return_doc_ids)
             if train_ds:
                 train_datasets.append(train_ds)
             if valid_ds:
@@ -67,6 +71,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         return (blending_train_dataset, blending_valid_dataset,
                 blending_test_dataset)

     else:
         print_rank_0("Separate data paths provided for train, valid & test. Split string will be ignored.")
@@ -74,23 +79,69 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         # Single dataset.
         if train_data_prefix is not None:
             train_dataset = build_dataset("train", train_data_prefix, data_impl,
-                                          train_valid_test_num_samples[0], seq_length, seed,
-                                          skip_warmup)
+                                          train_valid_test_num_samples[0],
+                                          seq_length, seed, skip_warmup)

         if valid_data_prefix is not None:
             valid_dataset = build_dataset("valid", valid_data_prefix, data_impl,
-                                          train_valid_test_num_samples[1], seq_length, seed,
-                                          False)
+                                          train_valid_test_num_samples[1],
+                                          seq_length, seed, False)

         if test_data_prefix is not None:
             test_dataset = build_dataset("test", test_data_prefix, data_impl,
-                                         train_valid_test_num_samples[2], seq_length, seed,
-                                         False)
+                                         train_valid_test_num_samples[2],
+                                         seq_length, seed, False)

     return (train_dataset, valid_dataset, test_dataset)
-def build_dataset(dataset_name, data_prefix, data_impl, num_samples, seq_length, seed, skip_warmup):
+def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                     train_valid_test_num_samples,
+                                     seq_length, seed, skip_warmup,
+                                     return_doc_ids=False):
+    """Build train, valid, and test datasets."""
+
+    # Indexed dataset.
+    indexed_dataset = get_indexed_dataset_(data_prefix,
+                                           data_impl,
+                                           skip_warmup)
+
+    total_num_of_documents = indexed_dataset.sizes.shape[0]
+    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
+
+    # Print stats about the splits.
+    print_rank_0(' > dataset split:')
+
+    def print_split_stats(name, index):
+        print_rank_0('    {}:'.format(name))
+        print_rank_0('     document indices in [{}, {}) total of {} '
+                     'documents'.format(splits[index], splits[index + 1],
+                                        splits[index + 1] - splits[index]))
+    print_split_stats('train', 0)
+    print_split_stats('validation', 1)
+    print_split_stats('test', 2)
+
+    def build_dataset(index, name):
+        dataset = None
+        if splits[index + 1] > splits[index]:
+            documents = np.arange(start=splits[index], stop=splits[index + 1],
+                                  step=1, dtype=np.int32)
+            dataset = GPTDataset(name, data_prefix,
+                                 documents, indexed_dataset,
+                                 train_valid_test_num_samples[index],
+                                 seq_length, seed,
+                                 return_doc_ids)
+        return dataset
+
+    train_dataset = build_dataset(0, 'train')
+    valid_dataset = build_dataset(1, 'valid')
+    test_dataset = build_dataset(2, 'test')
+
+    return (train_dataset, valid_dataset, test_dataset)
+
+
+def build_dataset(dataset_name, data_prefix, data_impl, num_samples,
+                  seq_length, seed, skip_warmup):
     dataset = None
     if len(data_prefix) == 1:
         dataset = _build_dataset(dataset_name,
@@ -119,7 +170,7 @@ def build_dataset(dataset_name, data_prefix, data_impl, num_samples, seq_length, seed, skip_warmup):
 def _build_dataset(dataset_name, data_prefix, data_impl,
                    num_samples, seq_length, seed, skip_warmup):
     """
     Build dataset. This method is called when individual
     train, valid, test datasets are provided
@@ -146,49 +197,6 @@ def _build_dataset(dataset_name, data_prefix, data_impl,
     return dataset


-def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
-                                     train_valid_test_num_samples,
-                                     seq_length, seed, skip_warmup):
-    """Build train, valid, and test datasets."""
-
-    # Indexed dataset.
-    indexed_dataset = get_indexed_dataset_(data_prefix,
-                                           data_impl,
-                                           skip_warmup)
-
-    total_num_of_documents = indexed_dataset.sizes.shape[0]
-    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
-
-    # Print stats about the splits.
-    print_rank_0(' > dataset split:')
-
-    def print_split_stats(name, index):
-        print_rank_0('    {}:'.format(name))
-        print_rank_0('     document indices in [{}, {}) total of {} '
-                     'documents'.format(splits[index], splits[index + 1],
-                                        splits[index + 1] - splits[index]))
-    print_split_stats('train', 0)
-    print_split_stats('validation', 1)
-    print_split_stats('test', 2)
-
-    def build_dataset(index, name):
-        dataset = None
-        if splits[index + 1] > splits[index]:
-            documents = np.arange(start=splits[index], stop=splits[index + 1],
-                                  step=1, dtype=np.int32)
-            dataset = GPTDataset(name, data_prefix,
-                                 documents, indexed_dataset,
-                                 train_valid_test_num_samples[index],
-                                 seq_length, seed)
-        return dataset
-
-    train_dataset = build_dataset(0, 'train')
-    valid_dataset = build_dataset(1, 'valid')
-    test_dataset = build_dataset(2, 'test')
-
-    return (train_dataset, valid_dataset, test_dataset)
-
-
 def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
     """Build indexed dataset."""
     print_rank_0(' > building dataset index ...')
@@ -208,19 +216,23 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
 class GPTDataset(torch.utils.data.Dataset):

     def __init__(self, name, data_prefix, documents, indexed_dataset,
-                 num_samples, seq_length, seed):
+                 num_samples, seq_length, seed,
+                 return_doc_ids=False):

         self.name = name
         self.indexed_dataset = indexed_dataset
+        self.return_doc_ids = return_doc_ids

         # Checks
         assert np.min(documents) >= 0
         assert np.max(documents) < indexed_dataset.sizes.shape[0]

         # Build index mappings.
-        self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
-            self.name, data_prefix, documents, self.indexed_dataset.sizes,
-            num_samples, seq_length, seed)
+        self.doc_idx, self.sample_idx, self.shuffle_idx, self.index_prefix = \
+            _build_index_mappings(self.name, data_prefix,
+                                  documents, self.indexed_dataset.sizes,
+                                  num_samples, seq_length, seed)

     def __len__(self):
         # -1 is due to data structure used to retrieve the index:
@@ -236,24 +248,33 @@ class GPTDataset(torch.utils.data.Dataset):
         offset_f = self.sample_idx[idx][1]
         offset_l = self.sample_idx[idx + 1][1]
         # If we are within the same document, just extract the chunk.
+        doc_ids = []
         if doc_index_f == doc_index_l:
+            doc_ids.append(self.doc_idx[doc_index_f])
             sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                               offset=offset_f,
                                               length=offset_l - offset_f + 1)
         else:
             # Otherwise, get the rest of the initial document.
+            doc_ids.append(self.doc_idx[doc_index_f])
             sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                                     offset=offset_f)]
             # Loop over all in between documents and add the entire document.
             for i in range(doc_index_f + 1, doc_index_l):
+                doc_ids.append(self.doc_idx[i])
                 sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
             # And finally add the relevant portion of last document.
+            doc_ids.append(self.doc_idx[doc_index_l])
             sample_list.append(self.indexed_dataset.get(
                 self.doc_idx[doc_index_l],
                 length=offset_l + 1))
             sample = np.concatenate(sample_list)

-        return {'text': np.array(sample, dtype=np.int64)}
+        if self.return_doc_ids: # for retro preprocessing
+            return {'text': np.array(sample, dtype=np.int64),
+                    'doc_ids': np.array(doc_ids, dtype=np.int64)}
+        else:
+            return {'text': np.array(sample, dtype=np.int64)}


 def _build_index_mappings(name, data_prefix, documents, sizes,
@@ -267,15 +288,16 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
     # Number of tokens in each epoch and number of required epochs.
     tokens_per_epoch = _num_tokens(documents, sizes)
     num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)

     # rng state
     np_rng = np.random.RandomState(seed=seed)

     # Filename of the index mappings.
-    _filename = data_prefix
-    _filename += '_{}_indexmap'.format(name)
-    _filename += '_{}ns'.format(num_samples)
-    _filename += '_{}sl'.format(seq_length)
-    _filename += '_{}s'.format(seed)
+    index_prefix = '{}_indexmap'.format(name)
+    index_prefix += '_{}ns'.format(num_samples)
+    index_prefix += '_{}sl'.format(seq_length)
+    index_prefix += '_{}s'.format(seed)
+    _filename = data_prefix + '_' + index_prefix
     doc_idx_filename = _filename + '_doc_idx.npy'
     sample_idx_filename = _filename + '_sample_idx.npy'
     shuffle_idx_filename = _filename + '_shuffle_idx.npy'
@@ -343,8 +365,6 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
         assert sizes.dtype == np.int32
         sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                               num_epochs, tokens_per_epoch)
-        # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
-        #                                num_epochs, tokens_per_epoch)
         np.save(sample_idx_filename, sample_idx, allow_pickle=True)
         print_rank_0(' > elapsed time to build and save sample-idx mapping '
                      '(seconds): {:4f}'.format(time.time() - start_time))
@@ -389,7 +409,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
                  sample_idx.shape[0]))
     print_rank_0('    total number of epochs: {}'.format(num_epochs))

-    return doc_idx, sample_idx, shuffle_idx
+    return doc_idx, sample_idx, shuffle_idx, index_prefix


 def _num_tokens(documents, sizes):
@@ -481,7 +501,7 @@ def _build_shuffle_idx(num_samples, total_size, np_rng):
     """Build the range [0, size) and shuffle."""
     print(' > building shuffle index with split [0, {}) and [{}, {}) '
           '...'.format(num_samples, num_samples, total_size), flush=True)

     dtype_ = np.uint32
     if total_size >= (np.iinfo(np.uint32).max - 1):
         dtype_ = np.int64
...
@@ -460,7 +460,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         return self._path

     def __setstate__(self, state):
-        self._do_init(state)
+        self._do_init(state, skip_warmup=True)

     def _do_init(self, path, skip_warmup):
         self._path = path
...
@@ -12,6 +12,7 @@ from .microbatches import build_num_microbatches_calculator
 from .timers import Timers

 _GLOBAL_ARGS = None
+_GLOBAL_RETRO_ARGS = None
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
 _GLOBAL_TOKENIZER = None
 _GLOBAL_TENSORBOARD_WRITER = None
@@ -25,6 +26,11 @@ def get_args():
     return _GLOBAL_ARGS


+def get_retro_args():
+    """Return retro arguments."""
+    return _GLOBAL_RETRO_ARGS
+
+
 def get_num_microbatches():
     return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
@@ -98,6 +104,11 @@ def set_args(args):
     _GLOBAL_ARGS = args


+def set_retro_args(retro_args):
+    global _GLOBAL_RETRO_ARGS
+    _GLOBAL_RETRO_ARGS = retro_args
+
+
 def _build_num_microbatches_calculator(args):
     global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
...
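The new accessors follow the same module-level-global pattern as `get_args`/`set_args`. A self-contained sketch of the flow, with stand-in functions mirroring `global_vars.py` above and a made-up chunk length:

```python
import types

# Stand-ins mirroring megatron/global_vars.py.
_GLOBAL_RETRO_ARGS = None

def set_retro_args(retro_args):
    global _GLOBAL_RETRO_ARGS
    _GLOBAL_RETRO_ARGS = retro_args

def get_retro_args():
    return _GLOBAL_RETRO_ARGS

# validate_args() calls set_retro_args(...) once at startup; downstream
# code then imports get_retro_args from megatron (see __init__.py above).
set_retro_args(types.SimpleNamespace(retro_gpt_chunk_length=64))
print(get_retro_args().retro_gpt_chunk_length)  # 64
```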
@@ -174,7 +174,7 @@ def _initialize_distributed():
         torch.distributed.init_process_group(
             backend=args.distributed_backend,
             world_size=args.world_size, rank=args.rank,
-            timeout=timedelta(minutes=10))
+            timeout=timedelta(minutes=args.distributed_timeout_minutes))

     # Set the tensor model-parallel, pipeline model-parallel, and
     # data-parallel communicators.
...
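This makes the previously hard-coded 10-minute process-group timeout configurable via `--distributed-timeout-minutes`. A standalone sketch of the same call (the single-process gloo group and loopback rendezvous are assumptions for illustration):

```python
from datetime import timedelta

import torch.distributed

# Single-process group purely to illustrate the configurable timeout;
# Megatron passes args.distributed_timeout_minutes here instead of 10.
torch.distributed.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500",
    world_size=1, rank=0, timeout=timedelta(minutes=10))
torch.distributed.destroy_process_group()
```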
@@ -16,6 +16,7 @@ from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule

 def bert_extended_attention_mask(attention_mask):
     # We create a 3D attention mask from a 2D tensor mask.
     # [b, 1, s]
@@ -137,6 +138,10 @@ class BertModel(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process

+        self.return_embeddings = args.output_bert_embeddings
+        if self.return_embeddings:
+            assert self.post_process and self.add_binary_head
+
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)
@@ -182,6 +187,24 @@ class BertModel(MegatronModule):
         if self.post_process and self.add_binary_head:
             lm_output, pooled_output = lm_output
+
+            # Return pooled output (e.g., when computing Bert embeddings).
+            if self.return_embeddings:
+
+                # Sum attention mask.
+                embeddings = torch.transpose(lm_output, 0, 1)
+                masks = torch.sum(attention_mask, dim=1)
+
+                # Collect masked embeddings.
+                output = torch.zeros(
+                    size=(embeddings.shape[0], embeddings.shape[2]),
+                    dtype=torch.float32,
+                    device=torch.cuda.current_device())
+                for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
+                    output[i, :] = torch.mean(embedding[1: mask - 1], dim=0)
+
+                return output
+
         else:
             pooled_output = None
...
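The embedding path above mean-pools each sequence's hidden states, with `embedding[1: mask - 1]` skipping the leading [CLS] token and the trailing [SEP]. A standalone sketch of just the pooling step, on CPU with random stand-in values:

```python
import torch

# Stand-ins: [batch, seq, hidden] hidden states and per-sequence valid
# token counts (what torch.sum(attention_mask, dim=1) produces above).
embeddings = torch.randn(2, 8, 16)
masks = torch.tensor([8, 5])

output = torch.zeros(embeddings.shape[0], embeddings.shape[2])
for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
    # Average tokens 1 .. mask-2, excluding [CLS] and [SEP].
    output[i, :] = torch.mean(embedding[1: mask - 1], dim=0)

print(output.shape)  # torch.Size([2, 16])
```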
@@ -74,13 +74,17 @@ class GPTModel(MegatronModule):
         """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)

-    def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, inference_params=None):
+    def forward(self, input_ids, position_ids, attention_mask,
+                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
+                labels=None, tokentype_ids=None, inference_params=None):

         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
+            ret_input_ids=ret_input_ids,
+            ret_position_ids=ret_position_ids,
+            ret_attn_mask=ret_attn_mask,
             inference_params=inference_params)

         if self.post_process:
...
@@ -7,11 +7,13 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron.core import mpu, tensor_parallel
+from .enums import LayerType, AttnMaskType
 from .module import MegatronModule
-from megatron.model.enums import LayerType, AttnMaskType
-from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import get_linear_layer
-from megatron.model.utils import init_method_normal, scaled_init_method_normal
+from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
+from .transformer import ParallelTransformer
+from .utils import get_linear_layer
+from .utils import init_method_normal, scaled_init_method_normal


 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
@@ -349,17 +351,39 @@ class TransformerLanguageModel(MegatronModule):
                 self.num_tokentypes)
             self._embedding_key = 'embedding'

-        # Transformer.
+        # Retriever (bi-directional transformer with cross attention)
+        if args.retro_add_retriever:
+            self.retriever = ParallelRetroEncoder(
+                self.init_method,
+                output_layer_init_method,
+                self_attn_mask_type=AttnMaskType.padding,
+                pre_process=self.pre_process,
+                post_process=False,
+            )
+            self._retriever_key = 'retriever'
+        else:
+            self.retriever = None
+
         # Encoder (usually set to True, False if part of an encoder-decoder
         # architecture and in encoder-only stage).
         if self.add_encoder:
-            self.encoder = ParallelTransformer(
-                self.init_method,
-                output_layer_init_method,
-                self_attn_mask_type=self.encoder_attn_mask_type,
-                pre_process=self.pre_process,
-                post_process=self.post_process
-            )
+            if args.retro_add_retriever:
+                self.encoder = ParallelRetroTransformer(
+                    self.init_method,
+                    output_layer_init_method,
+                    self_attn_mask_type=self.encoder_attn_mask_type,
+                    pre_process=self.pre_process,
+                    post_process=self.post_process,
+                    retriever=self.retriever,
+                )
+            else:
+                self.encoder = ParallelTransformer(
+                    self.init_method,
+                    output_layer_init_method,
+                    self_attn_mask_type=self.encoder_attn_mask_type,
+                    pre_process=self.pre_process,
+                    post_process=self.post_process,
+                )
             self._encoder_key = 'encoder'
         else:
             self.encoder = None
@@ -414,11 +438,19 @@ class TransformerLanguageModel(MegatronModule):
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
                 inference_params=None,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

+        # Retriever embedding.
+        if self.retriever and self.pre_process:
+            retriever_input = self.embedding(ret_input_ids, ret_position_ids,
+                                             tokentype_ids=tokentype_ids)
+        else:
+            retriever_input = None
+
         # Encoder embedding.
         if self.pre_process:
             encoder_input = self.embedding(enc_input_ids, enc_position_ids,
@@ -429,10 +461,18 @@ class TransformerLanguageModel(MegatronModule):
         # Run encoder.
         if enc_hidden_states is None:
             if self.encoder is not None:
-                encoder_output = self.encoder(
-                    encoder_input,
-                    enc_attn_mask,
-                    inference_params=inference_params)
+                if self.retriever:
+                    encoder_output = self.encoder(
+                        encoder_input,
+                        enc_attn_mask,
+                        retriever_output=retriever_input,
+                        retriever_attn_mask=ret_attn_mask,
+                        inference_params=inference_params)
+                else:
+                    encoder_output = self.encoder(
+                        encoder_input,
+                        enc_attn_mask,
+                        inference_params=inference_params)
             else:
                 encoder_output = self.encoder_hidden_state
         else:
...
...
@@ -270,7 +270,6 @@ class CoreAttention(MegatronModule):
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
-
         if not self.sequence_parallel:
             with tensor_parallel.get_cuda_rng_tracker().fork():
                 attention_probs = self.attention_dropout(attention_probs)
@@ -477,6 +476,7 @@ class ParallelAttention(MegatronModule):
         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
+
         if inference_params:
             if self.layer_number not in inference_params.key_value_memory_dict:
                 inf_max_seq_len = inference_params.max_sequence_len
...
@@ -244,6 +244,7 @@ def forward_backward_no_pipelining(forward_step_func,
         output_tensor = forward_step(forward_step_func, data_iterator,
                                      model, input_tensor, forward_data_store,
                                      timers, collect_non_loss_data)
+
         if not forward_only:
             backward_step(optimizer, input_tensor, output_tensor,
                           output_tensor_grad, timers)
...