Commit 762072e0 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'lmcafee/retro' into 'main'

Retro

See merge request ADLR/megatron-lm!489
parents ef59b687 17a6044d
......@@ -3,3 +3,4 @@ __pycache__
build
.coverage_*
*.egg-info
*~
......@@ -42,7 +42,9 @@ The following table shows both model (MFU) and hardware (HFU) FLOPs utilization
* [Distributed Pretraining](#distributed-pretraining)
* [Activation Checkpointing and Recomputation](#activation-checkpointing-and-recomputation)
* [Distributed Optimizer](#distributed-optimizer)
* [FlashAttention](#flashattention)
* [GPT-3 Example](#gpt-3-example)
* [Retro](#retro)
* [Evaluation and Tasks](#evaluation-and-tasks)
* [GPT Text Generation](#gpt-text-generation)
* [GPT Evaluation](#gpt-evaluation)
......@@ -323,6 +325,19 @@ In `examples/pretrain_gpt3_175B.sh` we have provided an example of how to config
With full global batch size of 1536 on 1024 A100 GPUs, each iteration takes around 32 seconds resulting in 138 teraFLOPs per GPU which is 44% of the theoretical peak FLOPs.
## Retro
See:
- `tools/retro/README.md` for an overview.
- `tools/retro/examples/get_preprocess_cmd.sh` for an example of common preprocessing arguments.
- `tools/retro/examples/preprocess_data.sh` for an example of how to preprocess data.
- `tools/retro/examples/pretrain_model.sh` for an example of how to pretrain a model.
Retro is a retrieval-enhanced model that is based on GPT. As described in [Improving language models by retrieving from trillions of tokens](https://arxiv.org/abs/2112.04426), Retro retrieves from a database of document chunks by performing locality search using a sample's tokens. The retrieval database can be large -- often billions or even trillions of tokens -- and provides a more efficient storage mechanism of factual knowledge, when compared to storing factual knowledge implicitly within the network's parameters.
Using Retro requires two steps: 1) preprocessing the retrieval database and pretraining neighbors, and 2) pretraining a model using this data. Please see `tools/retro/README.md` for a detailed overview.
<!--
## REALM Pipeline
We are working on implementing the [REALM](https://arxiv.org/pdf/2002.08909.pdf) system. The following sections (will) reflect the three stages of training it. For now it's just the ICT code.
......
......@@ -11,10 +11,10 @@ import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir, os.path.pardir)))
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.core import mpu
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel, ModelType
......
......@@ -10,10 +10,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir, os.path.pardir)))
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.core import mpu
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from .global_vars import get_args
from .global_vars import get_args, get_retro_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import get_signal_handler
......
......@@ -3,9 +3,14 @@
"""Megatron arguments."""
import argparse
import json
import os
import torch
import types
from megatron.global_vars import set_retro_args, get_retro_args
from tools.retro.utils import get_args_path as get_retro_args_path
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
......@@ -29,6 +34,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
parser = _add_transformer_engine_args(parser)
parser = _add_retro_args(parser)
# Custom arguments.
if extra_args_provider is not None:
......@@ -333,7 +339,6 @@ def validate_args(args, defaults={}):
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if args.sequence_parallel:
raise RuntimeError(
......@@ -344,15 +349,31 @@ def validate_args(args, defaults={}):
"Using async gradient all reduce requires setting the environment "
"variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
# Load retro args.
if args.retro_workdir:
retro_args_path = get_retro_args_path(args.retro_workdir)
if os.path.exists(retro_args_path):
with open(retro_args_path) as f:
retro_args = types.SimpleNamespace(**json.load(f))
retro_args.retro_return_doc_ids = args.retro_return_doc_ids
retro_args.retro_gpt_retrieved_length = \
args.retro_num_retrieved_chunks * \
retro_args.retro_gpt_chunk_length
set_retro_args(retro_args)
# Print arguments.
_print_args("arguments", args)
retro_args = get_retro_args()
if retro_args and args != retro_args:
_print_args("retro arguments", types.SimpleNamespace(**{k:v for k,v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank))
_print_args(args)
return args
def _print_args(args):
def _print_args(title, args):
"""Print arguments."""
if args.rank == 0:
print('------------------------ arguments ------------------------',
print(f'------------------------ {title} ------------------------',
flush=True)
str_list = []
for arg in vars(args):
......@@ -360,7 +381,7 @@ def _print_args(args):
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
print(arg, flush=True)
print('-------------------- end of arguments ---------------------',
print(f'-------------------- end of {title} ---------------------',
flush=True)
......@@ -403,15 +424,67 @@ def _add_inference_args(parser):
help='During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.')
group.add_argument('--max-tokens-to-oom',
type=int, default=12000,
help='Maximum number of tokens during inference'
'tokens here is # in prompt + # to generate'
'Allows us to throw an error before OOM crashes server')
group.add_argument('--output-bert-embeddings', action='store_true',
help='Output Bert embeddings (via mean pooling) from '
'model, rather than its binary head output or entire '
'hidden batch.')
group.add_argument('--bert-embedder-type', default="megatron",
choices=["megatron", "huggingface"],
help='Select either Megatron or Huggingface as the '
'Bert embedder.')
return parser
def _add_retro_args(parser):
group = parser.add_argument_group(title='retro')
group.add_argument('--retro-workdir', default=None,
help='Retro working directory, which contains the '
'preprocessed data for for pretraining. This directory '
'is built during preprocessing (see '
'tools/retro/README.md), and contains subdirectories '
'for the chunk database and pretraining neighbors.')
group.add_argument('--retro-add-retriever',
action='store_true', default=False,
help='Add a retriever to the transformer, for use in '
'pretraining a Retro model.')
group.add_argument('--retro-cyclic-train-iters', type=int, default=None,
help='Set number of training iterations for cyclic '
'Retro training.')
group.add_argument('--retro-encoder-layers', type=int, default=2,
help='Number of layers to use for the retrieval '
'encoder.')
group.add_argument('--retro-encoder-hidden-dropout',
type=float, default=0.1, help='Hidden dropout for '
'retrieval encoder.')
group.add_argument('--retro-encoder-attention-dropout',
type=float, default=0.1, help='Attention dropout for '
'retrieval encoder.')
group.add_argument("--retro-num-neighbors", type=int, default=2,
help='Number of neighbors to retrieve during '
'pretraining.')
group.add_argument("--retro-num-retrieved-chunks", type=int, default=2,
help='Number of chunks to retrieve from the retrieval '
'database.')
group.add_argument("--retro-return-doc-ids", action="store_true",
help="Turn this on when preprocessing retro data.")
# Enforce argument naming convention.
for action in group._group_actions:
prefix = action.dest.split("_")[0]
assert prefix == "retro", \
"Retro args must be prefixed with '--retro-*', for consistent " \
"styling. Please fix '%s'." % ", ".join(action.option_strings)
return parser
def _add_network_size_args(parser):
group = parser.add_argument_group(title='network size')
......@@ -775,6 +848,10 @@ def _add_checkpointing_args(parser):
group.add_argument('--use-checkpoint-args', action='store_true',
help='Override any command line arguments with arguments '
'from the checkpoint')
group.add_argument('--exit-on-missing-checkpoint', action='store_true',
help="If '--load' is set, but checkpoint is not found "
"(e.g., path typo), then exit instead of random "
"initialization.")
return parser
......@@ -835,6 +912,8 @@ def _add_distributed_args(parser):
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
group.add_argument('--distributed-timeout-minutes', type=int, default=10,
help='Timeout minutes for torch.distributed.')
group.add_argument('--DDP-impl', default='local',
choices=['local', 'torch'],
help='which DistributedDataParallel implementation '
......
......@@ -508,7 +508,16 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
use_distributed_optimizer=args.use_distributed_optimizer,
rank0=False)
# Checkpoint not loaded.
if model_state_dict is None:
# Conditionally exit at this point.
if args.exit_on_missing_checkpoint:
print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<")
torch.distributed.barrier()
sys.exit()
# Iteration defaults to 0.
return 0
# set checkpoint version
......
......@@ -19,7 +19,6 @@ from megatron.data.dataset_utils import (
create_masked_lm_predictions
)
class BertDataset(torch.utils.data.Dataset):
def __init__(self, name, indexed_dataset, data_prefix,
......@@ -156,7 +155,9 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
# Some checks.
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert padding_length >= 0, \
f"num_tokens ({num_tokens}) is greater than " \
"max_seq_length ({max_seq_length})."
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)
......
......@@ -50,4 +50,7 @@ class BlendableDataset(torch.utils.data.Dataset):
def __getitem__(self, idx):
dataset_idx = self.dataset_index[idx]
sample_idx = self.dataset_sample_index[idx]
return self.datasets[dataset_idx][sample_idx]
return {
"dataset_idx" : dataset_idx,
**self.datasets[dataset_idx][sample_idx],
}
......@@ -452,7 +452,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, binary_head, dataset_type=dataset_type)
seed, skip_warmup, binary_head, max_seq_length_dec,
dataset_type=dataset_type)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
......@@ -460,7 +461,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
if test_ds:
test_datasets.append(test_ds)
# Blend.
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
......
......@@ -16,15 +16,18 @@ from megatron.data.dataset_utils import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
def build_train_valid_test_datasets(data_prefix, data_impl,
splits_string, train_valid_test_num_samples,
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup,
train_data_prefix=None, valid_data_prefix=None,
test_data_prefix=None,):
train_data_prefix=None,
valid_data_prefix=None,
test_data_prefix=None,
return_doc_ids=False):
"""Build train, valid, and test datasets."""
if data_prefix:
print_rank_0("Single data path provided for train, valid & test")
# Single dataset.
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(data_prefix[0],
......@@ -35,7 +38,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl,
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
......@@ -46,7 +49,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl,
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
seq_length, seed, skip_warmup)
seq_length, seed, skip_warmup,
return_doc_ids)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
......@@ -67,6 +71,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl,
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
else:
print_rank_0("Separate data paths provided for train, valid & test. Split string will be ignored.")
......@@ -74,23 +79,69 @@ def build_train_valid_test_datasets(data_prefix, data_impl,
# Single dataset.
if train_data_prefix is not None:
train_dataset = build_dataset("train", train_data_prefix, data_impl,
train_valid_test_num_samples[0], seq_length, seed,
skip_warmup)
train_valid_test_num_samples[0],
seq_length, seed, skip_warmup)
if valid_data_prefix is not None:
valid_dataset = build_dataset("valid", valid_data_prefix, data_impl,
train_valid_test_num_samples[1], seq_length, seed,
False)
train_valid_test_num_samples[1],
seq_length, seed, False)
if test_data_prefix is not None:
test_dataset = build_dataset("test", test_data_prefix, data_impl,
train_valid_test_num_samples[2], seq_length, seed,
False)
train_valid_test_num_samples[2],
seq_length, seed, False)
return (train_dataset, valid_dataset, test_dataset)
def build_dataset(dataset_name, data_prefix, data_impl, num_samples, seq_length, seed, skip_warmup):
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup,
return_doc_ids=False):
"""Build train, valid, and test datasets."""
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
total_num_of_documents = indexed_dataset.sizes.shape[0]
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
documents = np.arange(start=splits[index], stop=splits[index + 1],
step=1, dtype=np.int32)
dataset = GPTDataset(name, data_prefix,
documents, indexed_dataset,
train_valid_test_num_samples[index],
seq_length, seed,
return_doc_ids)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def build_dataset(dataset_name, data_prefix, data_impl, num_samples,
seq_length, seed, skip_warmup):
dataset = None
if len(data_prefix) == 1:
dataset = _build_dataset(dataset_name,
......@@ -119,7 +170,7 @@ def build_dataset(dataset_name, data_prefix, data_impl, num_samples, seq_length,
def _build_dataset(dataset_name, data_prefix, data_impl,
num_samples, seq_length, seed, skip_warmup):
num_samples, seq_length, seed, skip_warmup):
"""
Build dataset. This method is called when individual
train, valid, test datasets are provided
......@@ -146,49 +197,6 @@ def _build_dataset(dataset_name, data_prefix, data_impl,
return dataset
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup):
"""Build train, valid, and test datasets."""
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
total_num_of_documents = indexed_dataset.sizes.shape[0]
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
documents = np.arange(start=splits[index], stop=splits[index + 1],
step=1, dtype=np.int32)
dataset = GPTDataset(name, data_prefix,
documents, indexed_dataset,
train_valid_test_num_samples[index],
seq_length, seed)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
"""Build indexed dataset."""
print_rank_0(' > building dataset index ...')
......@@ -208,19 +216,23 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
class GPTDataset(torch.utils.data.Dataset):
def __init__(self, name, data_prefix, documents, indexed_dataset,
num_samples, seq_length, seed):
num_samples, seq_length, seed,
return_doc_ids=False):
self.name = name
self.indexed_dataset = indexed_dataset
self.return_doc_ids = return_doc_ids
# Checks
assert np.min(documents) >= 0
assert np.max(documents) < indexed_dataset.sizes.shape[0]
# Build index mappings.
self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
self.name, data_prefix, documents, self.indexed_dataset.sizes,
num_samples, seq_length, seed)
self.doc_idx, self.sample_idx, self.shuffle_idx, self.index_prefix = \
_build_index_mappings(self.name, data_prefix,
documents, self.indexed_dataset.sizes,
num_samples, seq_length, seed)
def __len__(self):
# -1 is due to data structure used to retieve the index:
......@@ -236,24 +248,33 @@ class GPTDataset(torch.utils.data.Dataset):
offset_f = self.sample_idx[idx][1]
offset_l = self.sample_idx[idx + 1][1]
# If we are within the same document, just extract the chunk.
doc_ids = []
if doc_index_f == doc_index_l:
doc_ids.append(self.doc_idx[doc_index_f])
sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
offset=offset_f,
length=offset_l - offset_f + 1)
else:
# Otherwise, get the rest of the initial document.
doc_ids.append(self.doc_idx[doc_index_f])
sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
offset=offset_f)]
# Loop over all in between documents and add the entire document.
for i in range(doc_index_f + 1, doc_index_l):
doc_ids.append(self.doc_idx[i])
sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
# And finally add the relevant portion of last document.
doc_ids.append(self.doc_idx[doc_index_l])
sample_list.append(self.indexed_dataset.get(
self.doc_idx[doc_index_l],
length=offset_l + 1))
sample = np.concatenate(sample_list)
return {'text': np.array(sample, dtype=np.int64)}
if self.return_doc_ids: # for retro preprocessing
return {'text': np.array(sample, dtype=np.int64),
'doc_ids': np.array(doc_ids, dtype=np.int64)}
else:
return {'text': np.array(sample, dtype=np.int64)}
def _build_index_mappings(name, data_prefix, documents, sizes,
......@@ -267,15 +288,16 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
# Number of tokens in each epoch and number of required epochs.
tokens_per_epoch = _num_tokens(documents, sizes)
num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
# rng state
np_rng = np.random.RandomState(seed=seed)
# Filename of the index mappings.
_filename = data_prefix
_filename += '_{}_indexmap'.format(name)
_filename += '_{}ns'.format(num_samples)
_filename += '_{}sl'.format(seq_length)
_filename += '_{}s'.format(seed)
index_prefix = '{}_indexmap'.format(name)
index_prefix += '_{}ns'.format(num_samples)
index_prefix += '_{}sl'.format(seq_length)
index_prefix += '_{}s'.format(seed)
_filename = data_prefix + '_' + index_prefix
doc_idx_filename = _filename + '_doc_idx.npy'
sample_idx_filename = _filename + '_sample_idx.npy'
shuffle_idx_filename = _filename + '_shuffle_idx.npy'
......@@ -343,8 +365,6 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
assert sizes.dtype == np.int32
sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
num_epochs, tokens_per_epoch)
# sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
# num_epochs, tokens_per_epoch)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
print_rank_0(' > elasped time to build and save sample-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time))
......@@ -389,7 +409,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
sample_idx.shape[0]))
print_rank_0(' total number of epochs: {}'.format(num_epochs))
return doc_idx, sample_idx, shuffle_idx
return doc_idx, sample_idx, shuffle_idx, index_prefix
def _num_tokens(documents, sizes):
......@@ -481,7 +501,7 @@ def _build_shuffle_idx(num_samples, total_size, np_rng):
"""Build the range [0, size) and shuffle."""
print(' > building shuffle index with split [0, {}) and [{}, {}) '
'...'.format(num_samples, num_samples, total_size), flush=True)
dtype_ = np.uint32
if total_size >= (np.iinfo(np.uint32).max - 1):
dtype_ = np.int64
......
......@@ -460,7 +460,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
return self._path
def __setstate__(self, state):
self._do_init(state)
self._do_init(state, skip_warmup=True)
def _do_init(self, path, skip_warmup):
self._path = path
......
......@@ -12,6 +12,7 @@ from .microbatches import build_num_microbatches_calculator
from .timers import Timers
_GLOBAL_ARGS = None
_GLOBAL_RETRO_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
......@@ -25,6 +26,11 @@ def get_args():
return _GLOBAL_ARGS
def get_retro_args():
"""Return retro arguments."""
return _GLOBAL_RETRO_ARGS
def get_num_microbatches():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
......@@ -98,6 +104,11 @@ def set_args(args):
_GLOBAL_ARGS = args
def set_retro_args(retro_args):
global _GLOBAL_RETRO_ARGS
_GLOBAL_RETRO_ARGS = retro_args
def _build_num_microbatches_calculator(args):
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
......
......@@ -174,7 +174,7 @@ def _initialize_distributed():
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
timeout=timedelta(minutes=10))
timeout=timedelta(minutes=args.distributed_timeout_minutes))
# Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators.
......
......@@ -16,6 +16,7 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
......@@ -137,6 +138,10 @@ class BertModel(MegatronModule):
self.pre_process = pre_process
self.post_process = post_process
self.return_embeddings = args.output_bert_embeddings
if self.return_embeddings:
assert self.post_process and self.add_binary_head
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
......@@ -182,6 +187,24 @@ class BertModel(MegatronModule):
if self.post_process and self.add_binary_head:
lm_output, pooled_output = lm_output
# Return pooled output (e.g., when computing Bert embeddings).
if self.return_embeddings:
# Sum attention mask.
embeddings = torch.transpose(lm_output, 0, 1)
masks = torch.sum(attention_mask, dim=1)
# Collect masked embeddings.
output = torch.zeros(
size=(embeddings.shape[0], embeddings.shape[2]),
dtype=torch.float32,
device=torch.cuda.current_device())
for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
output[i, :] = torch.mean(embedding[1: mask - 1], dim=0)
return output
else:
pooled_output = None
......
......@@ -74,13 +74,17 @@ class GPTModel(MegatronModule):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, inference_params=None):
def forward(self, input_ids, position_ids, attention_mask,
ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
labels=None, tokentype_ids=None, inference_params=None):
lm_output = self.language_model(
input_ids,
position_ids,
attention_mask,
ret_input_ids=ret_input_ids,
ret_position_ids=ret_position_ids,
ret_attn_mask=ret_attn_mask,
inference_params=inference_params)
if self.post_process:
......
......@@ -7,11 +7,13 @@ import torch.nn.functional as F
from megatron import get_args
from megatron.core import mpu, tensor_parallel
from .enums import LayerType, AttnMaskType
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal
from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
from .transformer import ParallelTransformer
from .utils import get_linear_layer
from .utils import init_method_normal, scaled_init_method_normal
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
......@@ -349,17 +351,39 @@ class TransformerLanguageModel(MegatronModule):
self.num_tokentypes)
self._embedding_key = 'embedding'
# Transformer.
# Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in encoder-only stage).
if self.add_encoder:
self.encoder = ParallelTransformer(
# Retriever (bi-directional transformer with cross attention)
if args.retro_add_retriever:
self.retriever = ParallelRetroEncoder(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
self_attn_mask_type=AttnMaskType.padding,
pre_process=self.pre_process,
post_process=self.post_process
post_process=False,
)
self._retriever_key = 'retriever'
else:
self.retriever = None
# Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in encoder-only stage).
if self.add_encoder:
if args.retro_add_retriever:
self.encoder = ParallelRetroTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
retriever=self.retriever,
)
else:
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._encoder_key = 'encoder'
else:
self.encoder = None
......@@ -414,11 +438,19 @@ class TransformerLanguageModel(MegatronModule):
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None,
inference_params=None,
pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
# Retriever embedding.
if self.retriever and self.pre_process:
retriever_input = self.embedding(ret_input_ids, ret_position_ids,
tokentype_ids=tokentype_ids)
else:
retriever_input = None
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(enc_input_ids, enc_position_ids,
......@@ -429,10 +461,18 @@ class TransformerLanguageModel(MegatronModule):
# Run encoder.
if enc_hidden_states is None:
if self.encoder is not None:
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
inference_params=inference_params)
if self.retriever:
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
retriever_output=retriever_input,
retriever_attn_mask=ret_attn_mask,
inference_params=inference_params)
else:
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
inference_params=inference_params)
else:
encoder_output = self.encoder_hidden_state
else:
......
This diff is collapsed.
......@@ -270,7 +270,6 @@ class CoreAttention(MegatronModule):
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
......@@ -477,6 +476,7 @@ class ParallelAttention(MegatronModule):
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
......
......@@ -244,6 +244,7 @@ def forward_backward_no_pipelining(forward_step_func,
output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
timers, collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad, timers)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment