Commit 762072e0 authored by Jared Casper

Merge branch 'lmcafee/retro' into 'main'

Retro

See merge request ADLR/megatron-lm!489
parents ef59b687 17a6044d
@@ -3,3 +3,4 @@ __pycache__
 build
 .coverage_*
 *.egg-info
+*~
@@ -42,7 +42,9 @@ The following table shows both model (MFU) and hardware (HFU) FLOPs utilization
 * [Distributed Pretraining](#distributed-pretraining)
 * [Activation Checkpointing and Recomputation](#activation-checkpointing-and-recomputation)
 * [Distributed Optimizer](#distributed-optimizer)
+* [FlashAttention](#flashattention)
 * [GPT-3 Example](#gpt-3-example)
+* [Retro](#retro)
 * [Evaluation and Tasks](#evaluation-and-tasks)
 * [GPT Text Generation](#gpt-text-generation)
 * [GPT Evaluation](#gpt-evaluation)
@@ -323,6 +325,19 @@ In `examples/pretrain_gpt3_175B.sh` we have provided an example of how to config
 With full global batch size of 1536 on 1024 A100 GPUs, each iteration takes around 32 seconds resulting in 138 teraFLOPs per GPU which is 44% of the theoretical peak FLOPs.
+
+## Retro
+
+See:
+
+- `tools/retro/README.md` for an overview.
+- `tools/retro/examples/get_preprocess_cmd.sh` for an example of common preprocessing arguments.
+- `tools/retro/examples/preprocess_data.sh` for an example of how to preprocess data.
+- `tools/retro/examples/pretrain_model.sh` for an example of how to pretrain a model.
+
+Retro is a retrieval-enhanced model based on GPT. As described in [Improving language models by retrieving from trillions of tokens](https://arxiv.org/abs/2112.04426), Retro retrieves from a database of document chunks by performing locality search using a sample's tokens. The retrieval database can be large -- often billions or even trillions of tokens -- and provides more efficient storage of factual knowledge than storing it implicitly within the network's parameters.
+
+Using Retro requires two steps: 1) preprocessing the retrieval database and pretraining neighbors, and 2) pretraining a model using this data. Please see `tools/retro/README.md` for a detailed overview.
 <!--
 ## REALM Pipeline
 We are working on implementing the [REALM](https://arxiv.org/pdf/2002.08909.pdf) system. The following sections (will) reflect the three stages of training it. For now it's just the ICT code.
...
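To make the chunk database concrete, here is a minimal sketch (not part of this merge request; the 64-token chunk length is the Retro paper's default and an assumption here) of splitting tokenized documents into the fixed-length chunks the retrieval database is indexed on:

```python
# Hypothetical sketch: split tokenized documents into fixed-length chunks,
# the unit on which the Retro retrieval database is built.
import numpy as np

def build_chunk_db(token_docs, chunk_length=64):  # assumed chunk length
    """Return one (doc_id, start, end) entry per chunk."""
    chunk_db = []
    for doc_id, tokens in enumerate(token_docs):
        for start in range(0, len(tokens), chunk_length):
            end = min(start + chunk_length, len(tokens))
            chunk_db.append((doc_id, start, end))
    return np.array(chunk_db, dtype=np.int64)

docs = [np.arange(150), np.arange(70)]   # two toy "tokenized documents"
print(build_chunk_db(docs).shape)        # (5, 3): 3 chunks + 2 chunks
```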
@@ -11,10 +11,10 @@ import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir, os.path.pardir)))
 from megatron import get_args
-from megatron import print_rank_0
 from megatron import get_timers
 from megatron import get_tokenizer
-from megatron import mpu
+from megatron import print_rank_0
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, ModelType
...
@@ -10,10 +10,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir, os.path.pardir)))
 import torch
 from megatron import get_args
-from megatron import print_rank_0
 from megatron import get_tokenizer
-from megatron import mpu
+from megatron import print_rank_0
 from megatron.checkpointing import load_checkpoint
+from megatron.core import mpu
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel
 from megatron.training import get_model
...
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 import torch

-from .global_vars import get_args
+from .global_vars import get_args, get_retro_args
 from .global_vars import get_current_global_batch_size
 from .global_vars import get_num_microbatches
 from .global_vars import get_signal_handler
...
@@ -3,9 +3,14 @@
 """Megatron arguments."""

 import argparse
+import json
 import os

 import torch
+import types

+from megatron.global_vars import set_retro_args, get_retro_args
+from tools.retro.utils import get_args_path as get_retro_args_path
+

 def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     """Parse all arguments."""
...@@ -29,6 +34,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): ...@@ -29,6 +34,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
parser = _add_logging_args(parser) parser = _add_logging_args(parser)
parser = _add_inference_args(parser) parser = _add_inference_args(parser)
parser = _add_transformer_engine_args(parser) parser = _add_transformer_engine_args(parser)
parser = _add_retro_args(parser)
# Custom arguments. # Custom arguments.
if extra_args_provider is not None: if extra_args_provider is not None:
@@ -333,7 +339,6 @@ def validate_args(args, defaults={}):
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False

     if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
         if args.sequence_parallel:
             raise RuntimeError(
@@ -344,15 +349,31 @@ def validate_args(args, defaults={}):
                 "Using async gradient all reduce requires setting the environment "
                 "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")

+    # Load retro args.
+    if args.retro_workdir:
+        retro_args_path = get_retro_args_path(args.retro_workdir)
+        if os.path.exists(retro_args_path):
+            with open(retro_args_path) as f:
+                retro_args = types.SimpleNamespace(**json.load(f))
+                retro_args.retro_return_doc_ids = args.retro_return_doc_ids
+                retro_args.retro_gpt_retrieved_length = \
+                    args.retro_num_retrieved_chunks * \
+                    retro_args.retro_gpt_chunk_length
+                set_retro_args(retro_args)
+
-    _print_args(args)
+    # Print arguments.
+    _print_args("arguments", args)
+    retro_args = get_retro_args()
+    if retro_args and args != retro_args:
+        _print_args("retro arguments",
+                    types.SimpleNamespace(**{k: v for k, v in vars(retro_args).items()
+                                             if k.startswith("retro")},
+                                          rank=args.rank))

     return args
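The retro workdir mechanism above round-trips preprocessing arguments through JSON. A minimal sketch of that round trip (values are hypothetical; only `retro_gpt_chunk_length` and the derived `retro_gpt_retrieved_length` are taken from the diff):

```python
# Hypothetical sketch of the args round trip above: preprocessing dumps its
# arguments to JSON inside the workdir, and validate_args() reloads them as
# a plain namespace before deriving the retrieved sequence length.
import json
import types

saved = {"retro_gpt_chunk_length": 64}        # assumed preprocessing value
retro_args = types.SimpleNamespace(**json.loads(json.dumps(saved)))

retro_num_retrieved_chunks = 2                # assumed --retro-num-retrieved-chunks
retro_args.retro_gpt_retrieved_length = \
    retro_num_retrieved_chunks * retro_args.retro_gpt_chunk_length
print(retro_args.retro_gpt_retrieved_length)  # 128
```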
-def _print_args(args):
+def _print_args(title, args):
     """Print arguments."""
     if args.rank == 0:
-        print('------------------------ arguments ------------------------',
+        print(f'------------------------ {title} ------------------------',
               flush=True)
         str_list = []
         for arg in vars(args):
@@ -360,7 +381,7 @@ def _print_args(args):
             str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
         for arg in sorted(str_list, key=lambda x: x.lower()):
             print(arg, flush=True)
-        print('-------------------- end of arguments ---------------------',
+        print(f'-------------------- end of {title} ---------------------',
               flush=True)
@@ -403,12 +424,64 @@ def _add_inference_args(parser):
                        help='During inference, if batch-size times '
                        'sequence-length is smaller than this threshold '
                        'then we will not use pipelining, otherwise we will.')
+    group.add_argument('--max-tokens-to-oom',
+                       type=int, default=12000,
+                       help='Maximum number of tokens during inference; '
+                       'tokens here is # in prompt + # to generate. '
+                       'Allows us to throw an error before OOM crashes server.')
+    group.add_argument('--output-bert-embeddings', action='store_true',
+                       help='Output Bert embeddings (via mean pooling) from '
+                       'model, rather than its binary head output or entire '
+                       'hidden batch.')
+    group.add_argument('--bert-embedder-type', default="megatron",
+                       choices=["megatron", "huggingface"],
+                       help='Select either Megatron or Huggingface as the '
+                       'Bert embedder.')
+
+    return parser
+
+
+def _add_retro_args(parser):
+    group = parser.add_argument_group(title='retro')
+
+    group.add_argument('--retro-workdir', default=None,
+                       help='Retro working directory, which contains the '
+                       'preprocessed data for pretraining. This directory '
+                       'is built during preprocessing (see '
+                       'tools/retro/README.md), and contains subdirectories '
+                       'for the chunk database and pretraining neighbors.')
+    group.add_argument('--retro-add-retriever',
+                       action='store_true', default=False,
+                       help='Add a retriever to the transformer, for use in '
+                       'pretraining a Retro model.')
+    group.add_argument('--retro-cyclic-train-iters', type=int, default=None,
+                       help='Set number of training iterations for cyclic '
+                       'Retro training.')
+    group.add_argument('--retro-encoder-layers', type=int, default=2,
+                       help='Number of layers to use for the retrieval '
+                       'encoder.')
+    group.add_argument('--retro-encoder-hidden-dropout',
+                       type=float, default=0.1, help='Hidden dropout for '
+                       'retrieval encoder.')
+    group.add_argument('--retro-encoder-attention-dropout',
+                       type=float, default=0.1, help='Attention dropout for '
+                       'retrieval encoder.')
+    group.add_argument("--retro-num-neighbors", type=int, default=2,
+                       help='Number of neighbors to retrieve during '
+                       'pretraining.')
+    group.add_argument("--retro-num-retrieved-chunks", type=int, default=2,
+                       help='Number of chunks to retrieve from the retrieval '
+                       'database.')
+    group.add_argument("--retro-return-doc-ids", action="store_true",
+                       help="Turn this on when preprocessing retro data.")
+
+    # Enforce argument naming convention.
+    for action in group._group_actions:
+        prefix = action.dest.split("_")[0]
+        assert prefix == "retro", \
+            "Retro args must be prefixed with '--retro-*', for consistent " \
+            "styling. Please fix '%s'." % ", ".join(action.option_strings)

     return parser
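The naming-convention check above relies on argparse's dest mangling: `--retro-workdir` becomes the dest `retro_workdir`, so the first underscore-separated token of every dest in the group must be `retro`. A standalone illustration:

```python
# Standalone illustration of the convention check above. argparse converts
# '--retro-workdir' into the attribute name (dest) 'retro_workdir'.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='retro')
group.add_argument('--retro-workdir', default=None)
group.add_argument('--retro-num-neighbors', type=int, default=2)

for action in group._group_actions:
    assert action.dest.split('_')[0] == 'retro', action.option_strings
```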
@@ -775,6 +848,10 @@ def _add_checkpointing_args(parser):
     group.add_argument('--use-checkpoint-args', action='store_true',
                        help='Override any command line arguments with arguments '
                        'from the checkpoint')
+    group.add_argument('--exit-on-missing-checkpoint', action='store_true',
+                       help="If '--load' is set, but checkpoint is not found "
+                       "(e.g., path typo), then exit instead of random "
+                       "initialization.")

     return parser
@@ -835,6 +912,8 @@ def _add_distributed_args(parser):
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
+    group.add_argument('--distributed-timeout-minutes', type=int, default=10,
+                       help='Timeout minutes for torch.distributed.')
     group.add_argument('--DDP-impl', default='local',
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
...
@@ -508,7 +508,16 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
             use_distributed_optimizer=args.use_distributed_optimizer,
             rank0=False)

+    # Checkpoint not loaded.
     if model_state_dict is None:
+
+        # Conditionally exit at this point.
+        if args.exit_on_missing_checkpoint:
+            print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<")
+            torch.distributed.barrier()
+            sys.exit()
+
+        # Iteration defaults to 0.
         return 0

     # set checkpoint version
...
@@ -19,7 +19,6 @@ from megatron.data.dataset_utils import (
     create_masked_lm_predictions
 )

 class BertDataset(torch.utils.data.Dataset):

     def __init__(self, name, indexed_dataset, data_prefix,
@@ -156,7 +155,9 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     # Some checks.
     num_tokens = len(tokens)
     padding_length = max_seq_length - num_tokens
-    assert padding_length >= 0
+    assert padding_length >= 0, \
+        f"num_tokens ({num_tokens}) is greater than " \
+        f"max_seq_length ({max_seq_length})."
     assert len(tokentypes) == num_tokens
     assert len(masked_positions) == len(masked_labels)
...
@@ -50,4 +50,7 @@ class BlendableDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         dataset_idx = self.dataset_index[idx]
         sample_idx = self.dataset_sample_index[idx]
-        return self.datasets[dataset_idx][sample_idx]
+        return {
+            "dataset_idx" : dataset_idx,
+            **self.datasets[dataset_idx][sample_idx],
+        }
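A toy illustration of what the change above means for consumers of `BlendableDataset` (the real constituent datasets return dicts, e.g. GPTDataset's `{'text': ...}`):

```python
# Toy sketch: the constituent dataset's sample dict is preserved, and the
# index of the originating dataset is added under "dataset_idx".
datasets = [
    [{'text': [1, 2, 3]}],   # dataset 0
    [{'text': [7, 8, 9]}],   # dataset 1
]
dataset_idx, sample_idx = 1, 0
sample = {
    "dataset_idx": dataset_idx,
    **datasets[dataset_idx][sample_idx],
}
assert sample == {"dataset_idx": 1, "text": [7, 8, 9]}
```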
@@ -452,7 +452,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             prefixes[i], data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             max_seq_length, masked_lm_prob, short_seq_prob,
-            seed, skip_warmup, binary_head, dataset_type=dataset_type)
+            seed, skip_warmup, binary_head, max_seq_length_dec,
+            dataset_type=dataset_type)
         if train_ds:
             train_datasets.append(train_ds)
         if valid_ds:
...
@@ -16,15 +16,18 @@ from megatron.data.dataset_utils import get_train_valid_test_split_
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset


-def build_train_valid_test_datasets(data_prefix, data_impl,
-                                    splits_string, train_valid_test_num_samples,
+def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                    train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup,
-                                    train_data_prefix=None, valid_data_prefix=None,
-                                    test_data_prefix=None,):
+                                    train_data_prefix=None,
+                                    valid_data_prefix=None,
+                                    test_data_prefix=None,
+                                    return_doc_ids=False):
     """Build train, valid, and test datasets."""

     if data_prefix:
         print_rank_0("Single data path provided for train, valid & test")

         # Single dataset.
         if len(data_prefix) == 1:
             return _build_train_valid_test_datasets(data_prefix[0],
@@ -46,7 +49,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl,
             train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
                 prefixes[i], data_impl, splits_string,
                 datasets_train_valid_test_num_samples[i],
-                seq_length, seed, skip_warmup)
+                seq_length, seed, skip_warmup,
+                return_doc_ids)
             if train_ds:
                 train_datasets.append(train_ds)
             if valid_ds:
@@ -67,6 +71,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl,
         return (blending_train_dataset, blending_valid_dataset,
                 blending_test_dataset)

+
     else:
         print_rank_0("Separate data paths provided for train, valid & test. Split string will be ignored.")
@@ -74,23 +79,69 @@ def build_train_valid_test_datasets(data_prefix, data_impl,
         # Single dataset.
         if train_data_prefix is not None:
             train_dataset = build_dataset("train", train_data_prefix, data_impl,
-                                          train_valid_test_num_samples[0], seq_length, seed,
-                                          skip_warmup)
+                                          train_valid_test_num_samples[0],
+                                          seq_length, seed, skip_warmup)

         if valid_data_prefix is not None:
             valid_dataset = build_dataset("valid", valid_data_prefix, data_impl,
-                                          train_valid_test_num_samples[1], seq_length, seed,
-                                          False)
+                                          train_valid_test_num_samples[1],
+                                          seq_length, seed, False)

         if test_data_prefix is not None:
             test_dataset = build_dataset("test", test_data_prefix, data_impl,
-                                         train_valid_test_num_samples[2], seq_length, seed,
-                                         False)
+                                         train_valid_test_num_samples[2],
+                                         seq_length, seed, False)

+    return (train_dataset, valid_dataset, test_dataset)
+
+
+def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                     train_valid_test_num_samples,
+                                     seq_length, seed, skip_warmup,
+                                     return_doc_ids=False):
+    """Build train, valid, and test datasets."""
+
+    # Indexed dataset.
+    indexed_dataset = get_indexed_dataset_(data_prefix,
+                                           data_impl,
+                                           skip_warmup)
+
+    total_num_of_documents = indexed_dataset.sizes.shape[0]
+    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
+
+    # Print stats about the splits.
+    print_rank_0(' > dataset split:')
+
+    def print_split_stats(name, index):
+        print_rank_0('    {}:'.format(name))
+        print_rank_0('     document indices in [{}, {}) total of {} '
+                     'documents'.format(splits[index], splits[index + 1],
+                                        splits[index + 1] - splits[index]))
+    print_split_stats('train', 0)
+    print_split_stats('validation', 1)
+    print_split_stats('test', 2)
+
+    def build_dataset(index, name):
+        dataset = None
+        if splits[index + 1] > splits[index]:
+            documents = np.arange(start=splits[index], stop=splits[index + 1],
+                                  step=1, dtype=np.int32)
+            dataset = GPTDataset(name, data_prefix,
+                                 documents, indexed_dataset,
+                                 train_valid_test_num_samples[index],
+                                 seq_length, seed,
+                                 return_doc_ids)
+        return dataset
+
+    train_dataset = build_dataset(0, 'train')
+    valid_dataset = build_dataset(1, 'valid')
+    test_dataset = build_dataset(2, 'test')

     return (train_dataset, valid_dataset, test_dataset)
-def build_dataset(dataset_name, data_prefix, data_impl, num_samples, seq_length, seed, skip_warmup):
+def build_dataset(dataset_name, data_prefix, data_impl, num_samples,
+                  seq_length, seed, skip_warmup):
     dataset = None
     if len(data_prefix) == 1:
         dataset = _build_dataset(dataset_name,
@@ -146,49 +197,6 @@ def _build_dataset(dataset_name, data_prefix, data_impl,
     return dataset

-
-def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
-                                     train_valid_test_num_samples,
-                                     seq_length, seed, skip_warmup):
-    """Build train, valid, and test datasets."""
-
-    # Indexed dataset.
-    indexed_dataset = get_indexed_dataset_(data_prefix,
-                                           data_impl,
-                                           skip_warmup)
-
-    total_num_of_documents = indexed_dataset.sizes.shape[0]
-    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
-
-    # Print stats about the splits.
-    print_rank_0(' > dataset split:')
-
-    def print_split_stats(name, index):
-        print_rank_0('    {}:'.format(name))
-        print_rank_0('     document indices in [{}, {}) total of {} '
-                     'documents'.format(splits[index], splits[index + 1],
-                                        splits[index + 1] - splits[index]))
-    print_split_stats('train', 0)
-    print_split_stats('validation', 1)
-    print_split_stats('test', 2)
-
-    def build_dataset(index, name):
-        dataset = None
-        if splits[index + 1] > splits[index]:
-            documents = np.arange(start=splits[index], stop=splits[index + 1],
-                                  step=1, dtype=np.int32)
-            dataset = GPTDataset(name, data_prefix,
-                                 documents, indexed_dataset,
-                                 train_valid_test_num_samples[index],
-                                 seq_length, seed)
-        return dataset
-
-    train_dataset = build_dataset(0, 'train')
-    valid_dataset = build_dataset(1, 'valid')
-    test_dataset = build_dataset(2, 'test')
-
-    return (train_dataset, valid_dataset, test_dataset)

 def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
     """Build indexed dataset."""
     print_rank_0(' > building dataset index ...')
@@ -208,20 +216,24 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

 class GPTDataset(torch.utils.data.Dataset):

     def __init__(self, name, data_prefix, documents, indexed_dataset,
-                 num_samples, seq_length, seed):
+                 num_samples, seq_length, seed,
+                 return_doc_ids=False):

         self.name = name
         self.indexed_dataset = indexed_dataset
+        self.return_doc_ids = return_doc_ids

         # Checks
         assert np.min(documents) >= 0
         assert np.max(documents) < indexed_dataset.sizes.shape[0]

         # Build index mappings.
-        self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
-            self.name, data_prefix, documents, self.indexed_dataset.sizes,
-            num_samples, seq_length, seed)
+        self.doc_idx, self.sample_idx, self.shuffle_idx, self.index_prefix = \
+            _build_index_mappings(self.name, data_prefix,
+                                  documents, self.indexed_dataset.sizes,
+                                  num_samples, seq_length, seed)

     def __len__(self):
         # -1 is due to data structure used to retrieve the index:
         #   sample i --> [sample_idx[i], sample_idx[i+1])
@@ -236,23 +248,32 @@ class GPTDataset(torch.utils.data.Dataset):
         offset_f = self.sample_idx[idx][1]
         offset_l = self.sample_idx[idx + 1][1]
         # If we are within the same document, just extract the chunk.
+        doc_ids = []
         if doc_index_f == doc_index_l:
+            doc_ids.append(self.doc_idx[doc_index_f])
             sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                               offset=offset_f,
                                               length=offset_l - offset_f + 1)
         else:
             # Otherwise, get the rest of the initial document.
+            doc_ids.append(self.doc_idx[doc_index_f])
             sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                                     offset=offset_f)]
             # Loop over all in between documents and add the entire document.
             for i in range(doc_index_f + 1, doc_index_l):
+                doc_ids.append(self.doc_idx[i])
                 sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
             # And finally add the relevant portion of last document.
+            doc_ids.append(self.doc_idx[doc_index_l])
             sample_list.append(self.indexed_dataset.get(
                 self.doc_idx[doc_index_l],
                 length=offset_l + 1))
             sample = np.concatenate(sample_list)

-        return {'text': np.array(sample, dtype=np.int64)}
+        if self.return_doc_ids: # for retro preprocessing
+            return {'text': np.array(sample, dtype=np.int64),
+                    'doc_ids': np.array(doc_ids, dtype=np.int64)}
+        else:
+            return {'text': np.array(sample, dtype=np.int64)}
@@ -267,15 +288,16 @@ def _build_index_mappings(name, data_prefix, documents, sizes,

     # Number of tokens in each epoch and number of required epochs.
     tokens_per_epoch = _num_tokens(documents, sizes)
     num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
+
     # rng state
     np_rng = np.random.RandomState(seed=seed)

     # Filename of the index mappings.
-    _filename = data_prefix
-    _filename += '_{}_indexmap'.format(name)
-    _filename += '_{}ns'.format(num_samples)
-    _filename += '_{}sl'.format(seq_length)
-    _filename += '_{}s'.format(seed)
+    index_prefix = '{}_indexmap'.format(name)
+    index_prefix += '_{}ns'.format(num_samples)
+    index_prefix += '_{}sl'.format(seq_length)
+    index_prefix += '_{}s'.format(seed)
+    _filename = data_prefix + '_' + index_prefix
     doc_idx_filename = _filename + '_doc_idx.npy'
     sample_idx_filename = _filename + '_sample_idx.npy'
     shuffle_idx_filename = _filename + '_shuffle_idx.npy'
@@ -343,8 +365,6 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
             assert sizes.dtype == np.int32
             sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                                   num_epochs, tokens_per_epoch)
-            # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
-            #                                num_epochs, tokens_per_epoch)
             np.save(sample_idx_filename, sample_idx, allow_pickle=True)
             print_rank_0(' > elapsed time to build and save sample-idx mapping '
                          '(seconds): {:4f}'.format(time.time() - start_time))
@@ -389,7 +409,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
                            sample_idx.shape[0]))
     print_rank_0('    total number of epochs: {}'.format(num_epochs))

-    return doc_idx, sample_idx, shuffle_idx
+    return doc_idx, sample_idx, shuffle_idx, index_prefix


 def _num_tokens(documents, sizes):
...
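For reference, the renamed index files produced by the `index_prefix` scheme above look like this (path and values hypothetical):

```python
# Hypothetical example of the index-prefix naming introduced above.
name, num_samples, seq_length, seed = 'train', 1000, 2048, 1234
data_prefix = '/data/mycorpus_text_document'   # assumed data path

index_prefix = '{}_indexmap'.format(name)
index_prefix += '_{}ns'.format(num_samples)
index_prefix += '_{}sl'.format(seq_length)
index_prefix += '_{}s'.format(seed)
_filename = data_prefix + '_' + index_prefix

print(_filename + '_doc_idx.npy')
# /data/mycorpus_text_document_train_indexmap_1000ns_2048sl_1234s_doc_idx.npy
```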
@@ -460,7 +460,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         return self._path

     def __setstate__(self, state):
-        self._do_init(state)
+        self._do_init(state, skip_warmup=True)

     def _do_init(self, path, skip_warmup):
         self._path = path
...
@@ -12,6 +12,7 @@ from .microbatches import build_num_microbatches_calculator
 from .timers import Timers

 _GLOBAL_ARGS = None
+_GLOBAL_RETRO_ARGS = None
 _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
 _GLOBAL_TOKENIZER = None
 _GLOBAL_TENSORBOARD_WRITER = None
@@ -25,6 +26,11 @@ def get_args():
     return _GLOBAL_ARGS


+def get_retro_args():
+    """Return retro arguments."""
+    return _GLOBAL_RETRO_ARGS
+
+
 def get_num_microbatches():
     return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
@@ -98,6 +104,11 @@ def set_args(args):
     _GLOBAL_ARGS = args


+def set_retro_args(retro_args):
+    global _GLOBAL_RETRO_ARGS
+    _GLOBAL_RETRO_ARGS = retro_args
+
+
 def _build_num_microbatches_calculator(args):
     global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
...
@@ -174,7 +174,7 @@ def _initialize_distributed():
         torch.distributed.init_process_group(
             backend=args.distributed_backend,
             world_size=args.world_size, rank=args.rank,
-            timeout=timedelta(minutes=10))
+            timeout=timedelta(minutes=args.distributed_timeout_minutes))

     # Set the tensor model-parallel, pipeline model-parallel, and
     # data-parallel communicators.
...
@@ -16,6 +16,7 @@ from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule

+
 def bert_extended_attention_mask(attention_mask):
     # We create a 3D attention mask from a 2D tensor mask.
     # [b, 1, s]
@@ -137,6 +138,10 @@ class BertModel(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process

+        self.return_embeddings = args.output_bert_embeddings
+        if self.return_embeddings:
+            assert self.post_process and self.add_binary_head
+
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)
@@ -182,6 +187,24 @@ class BertModel(MegatronModule):
         if self.post_process and self.add_binary_head:
             lm_output, pooled_output = lm_output

+            # Return pooled output (e.g., when computing Bert embeddings).
+            if self.return_embeddings:
+
+                # Sum attention mask.
+                embeddings = torch.transpose(lm_output, 0, 1)
+                masks = torch.sum(attention_mask, dim=1)
+
+                # Collect masked embeddings.
+                output = torch.zeros(
+                    size=(embeddings.shape[0], embeddings.shape[2]),
+                    dtype=torch.float32,
+                    device=torch.cuda.current_device())
+                for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
+                    output[i, :] = torch.mean(embedding[1: mask - 1], dim=0)
+
+                return output
+
         else:
             pooled_output = None
...
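The embedding path added above mean-pools the hidden states strictly between [CLS] and [SEP]. A self-contained CPU sketch of the same computation, assuming `lm_output` is `[s, b, h]` and `attention_mask` is a `[b, s]` 0/1 mask:

```python
# Minimal CPU sketch of the mean pooling added above (assumed shapes).
import torch

s, b, h = 8, 2, 16
lm_output = torch.randn(s, b, h)                     # [s, b, h]
attention_mask = torch.ones(b, s, dtype=torch.long)  # 1 = real token

embeddings = torch.transpose(lm_output, 0, 1)        # [b, s, h]
masks = torch.sum(attention_mask, dim=1)             # valid length per sample

output = torch.zeros(size=(b, h), dtype=torch.float32)
for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
    # Average positions 1 .. mask-2, i.e. skip [CLS] and [SEP].
    output[i, :] = torch.mean(embedding[1: mask - 1], dim=0)
```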
@@ -74,13 +74,17 @@ class GPTModel(MegatronModule):
         """See megatron.model.transformer.set_input_tensor()"""
         self.language_model.set_input_tensor(input_tensor)

-    def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, inference_params=None):
+    def forward(self, input_ids, position_ids, attention_mask,
+                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
+                labels=None, tokentype_ids=None, inference_params=None):

         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
+            ret_input_ids=ret_input_ids,
+            ret_position_ids=ret_position_ids,
+            ret_attn_mask=ret_attn_mask,
             inference_params=inference_params)

         if self.post_process:
...
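A sketch of the new retriever inputs threaded through `forward()` above. Shapes are assumptions, not from the diff: with `k` neighbors of `r` tokens for each of `l` chunks, the neighbor tokens are plausibly flattened into one batch before embedding:

```python
# Assumed-shape sketch of the new GPTModel.forward() retriever arguments.
import torch

b, s = 2, 2048          # batch size, sequence length (assumed)
l, k, r = 32, 2, 128    # chunks, neighbors, tokens per neighbor (assumed)

input_ids = torch.randint(0, 50257, (b, s))
ret_input_ids = torch.randint(0, 50257, (b * l * k, r))  # neighbor tokens
ret_position_ids = torch.arange(r).repeat(b * l * k, 1)

# model(input_ids, position_ids, attention_mask,
#       ret_input_ids=ret_input_ids,
#       ret_position_ids=ret_position_ids,
#       ret_attn_mask=ret_attn_mask)
```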
@@ -7,11 +7,13 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron.core import mpu, tensor_parallel

+from .enums import LayerType, AttnMaskType
 from .module import MegatronModule
-from megatron.model.enums import LayerType, AttnMaskType
-from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import get_linear_layer
-from megatron.model.utils import init_method_normal, scaled_init_method_normal
+from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
+from .transformer import ParallelTransformer
+from .utils import get_linear_layer
+from .utils import init_method_normal, scaled_init_method_normal


 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
@@ -349,16 +351,38 @@ class TransformerLanguageModel(MegatronModule):
                                        self.num_tokentypes)
             self._embedding_key = 'embedding'

-        # Transformer.
+        # Retriever (bi-directional transformer with cross attention)
+        if args.retro_add_retriever:
+            self.retriever = ParallelRetroEncoder(
+                self.init_method,
+                output_layer_init_method,
+                self_attn_mask_type=AttnMaskType.padding,
+                pre_process=self.pre_process,
+                post_process=False,
+            )
+            self._retriever_key = 'retriever'
+        else:
+            self.retriever = None
+
         # Encoder (usually set to True, False if part of an encoder-decoder
         # architecture and in encoder-only stage).
         if self.add_encoder:
+            if args.retro_add_retriever:
+                self.encoder = ParallelRetroTransformer(
+                    self.init_method,
+                    output_layer_init_method,
+                    self_attn_mask_type=self.encoder_attn_mask_type,
+                    pre_process=self.pre_process,
+                    post_process=self.post_process,
+                    retriever=self.retriever,
+                )
+            else:
-            self.encoder = ParallelTransformer(
-                self.init_method,
-                output_layer_init_method,
-                self_attn_mask_type=self.encoder_attn_mask_type,
-                pre_process=self.pre_process,
-                post_process=self.post_process
-            )
+                self.encoder = ParallelTransformer(
+                    self.init_method,
+                    output_layer_init_method,
+                    self_attn_mask_type=self.encoder_attn_mask_type,
+                    pre_process=self.pre_process,
+                    post_process=self.post_process,
+                )
             self._encoder_key = 'encoder'
         else:
@@ -414,11 +438,19 @@ class TransformerLanguageModel(MegatronModule):

     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None,
                 inference_params=None,
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):

+        # Retriever embedding.
+        if self.retriever and self.pre_process:
+            retriever_input = self.embedding(ret_input_ids, ret_position_ids,
+                                             tokentype_ids=tokentype_ids)
+        else:
+            retriever_input = None
+
         # Encoder embedding.
         if self.pre_process:
             encoder_input = self.embedding(enc_input_ids, enc_position_ids,
@@ -429,6 +461,14 @@ class TransformerLanguageModel(MegatronModule):
         # Run encoder.
         if enc_hidden_states is None:
             if self.encoder is not None:
+                if self.retriever:
+                    encoder_output = self.encoder(
+                        encoder_input,
+                        enc_attn_mask,
+                        retriever_output=retriever_input,
+                        retriever_attn_mask=ret_attn_mask,
+                        inference_params=inference_params)
+                else:
-                encoder_output = self.encoder(
-                    encoder_input,
-                    enc_attn_mask,
+                    encoder_output = self.encoder(
+                        encoder_input,
+                        enc_attn_mask,
...
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Retro Transformer.
** Special note about this file **
Many classes and methods in this file directly parallel those in transformer.py
in name and utility. However, due to 1) subtle changes in the code over time
(i.e., transposes and contexts), and 2) other code that is soon to be merged,
this file will *temporarily* remain as is, until a larger integration is
complete.
"""
import math
import numpy as np
import torch
import torch.nn.functional as F
from megatron import get_args, get_retro_args, get_tensorboard_writer
from megatron.core import parallel_state
from megatron.core import tensor_parallel
from megatron.core import utils as core_utils
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, init_method_normal
from .module import MegatronModule
from .transformer import _get_num_layers, ParallelMLP, NoopTransformerLayer
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
p: number of model parallel partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
"""
class DropPath(MegatronModule):
"""Drop paths (Stochastic Depth) per sample
(when applied in main path of residual blocks).
*Note: differs from transformer.py/DropPath in hidden_state transpose.
"""
def __init__(self, drop_prob=0.):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, hidden_state):
if self.drop_prob == 0. or not self.training:
return hidden_state
keep_prob = 1 - self.drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
random_tensor.floor_() # binarize
output = hidden_state.div(keep_prob) * random_tensor
return output
class SwitchMLP(MegatronModule):
"""
Routes input to one of N MLP "experts"
"""
def __init__(self, init_method, output_layer_init_method):
super(SwitchMLP, self).__init__()
args = get_args()
self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
self.experts = torch.nn.ModuleList()
for i in range(args.num_experts):
self.experts.append(ParallelMLP(init_method, output_layer_init_method))
def forward(self, hidden_states):
# hidden_states: [b, s, h]
b = hidden_states.size(0)
s = hidden_states.size(1)
h = hidden_states.size(2)
route = self.router(hidden_states)
route = torch.nn.functional.softmax(route, dim=2)
max_prob, max_ind = torch.max(route, dim=2)
max_prob = torch.unsqueeze(max_prob, 2) # [b s 1]
# TODO (rprenger) TODO this could be made easier to read
# Converting [b, s, h] to [b*s, h].
# Each vector could be routed differently
hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [b*s h]
max_prob = max_prob.view(-1, max_prob.size(2)) # [b*s 1]
max_ind = max_ind.view(-1) # [b*s]
output_total = torch.empty_like(hidden_states)
output_bias_total = torch.empty_like(hidden_states)
#TODO (rprenger) This does each expert in serial, but it could be parallelized
for expert_num, expert in enumerate(self.experts):
local_indices = (max_ind == expert_num).nonzero()
hidden = hidden_states[local_indices,:]
output, output_bias = expert(hidden)
output_bias = output_bias.expand_as(output)
output_total[local_indices,:] = output
output_bias_total[local_indices,:] = output_bias
output_total = output_total*max_prob
output_bias_total = output_bias_total*max_prob
output_total = output_total.view(b, s, h)
output_bias_total = output_bias_total.view(b, s, h)
return output_total, output_bias_total
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = core_utils.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = core_utils.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = core_utils.divide(
args.num_attention_heads, world_size)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
self.key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
# Output.
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device())
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory, inference_value_memory)
else:
inference_key_memory, inference_value_memory = \
inference_params.key_value_memory_dict[self.layer_number]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = tensor_parallel \
.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = tensor_parallel \
.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if inference_params:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = key_layer
inference_value_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[
:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...]
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
# preallocting result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
return output, bias
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False)
class ParallelRetroTransformerEncoderLayer(MegatronModule):
"""A single transformer layer for Retro Decoder with an retriever encoder inside and cross attention.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0., retriever=None):
args = get_args()
super(ParallelRetroTransformerEncoderLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Retro Encoder
self.retriever = retriever
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
# jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention. # [ns, bs, d]
layernorm_output = self.post_attention_layernorm(layernorm_input)
"""
notations:
l: number of chunks
m: number of token per chunk
bs: batch size
d: hidden size
k: number of neighbors
r: number of tokens per neighbors (neighbors + continuation)
"""
args = get_args()
retro_args = get_retro_args()
chunk_length = retro_args.retro_gpt_chunk_length
retrieved_length = retro_args.retro_gpt_retrieved_length
num_neighbors = args.retro_num_neighbors
ns, bs, d = layernorm_output.shape
l = int(np.ceil(ns / chunk_length))
first_ns = ns % chunk_length
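        # If ns is not a multiple of chunk_length, the leading partial chunk
        # is right-padded with zeros to a full chunk. E.g. (illustrative):
        # ns=130, m=64 -> l=3, first_ns=2; the 2 leading tokens are padded
        # to 64, giving 3 full chunks of 64.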
if first_ns > 0:
first_chunk, rest_chunk = \
layernorm_output[:first_ns], layernorm_output[first_ns:]
first_chunk = torch.nn.functional.pad(
first_chunk,
(0, 0, 0, 0, 0, chunk_length - first_ns),
'constant',
0)
chunked_output = \
torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d]
else:
chunked_output = layernorm_output # [l * m, bs, d]
chunked_output = chunked_output \
.reshape(l, chunk_length, bs, d) \
.permute(1, 2, 0, 3) \
.reshape(chunk_length, bs * l, d) \
.contiguous()
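        # Reshape trace: [l * m, bs, d] -> [l, m, bs, d] -> [m, bs, l, d]
        # -> [m, bs * l, d], i.e. each chunk becomes an independent "batch"
        # entry for the retriever encoder.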
# Get Encoder Output
retriever_output = self.retriever(
retriever_output,
retriever_attn_mask,
retriever_output=chunked_output,
retriever_attn_mask=retriever_attn_mask,
            inference_params=inference_params) # [r, k * bs * l, d]
retriever_output = retriever_output.reshape(
retrieved_length * num_neighbors, bs * l, d) # [r * k, bs * l, d]
        # Chunked cross-attention with the retriever encoder output.
pad = (ns - 1) % chunk_length
attending_chunks = layernorm_output[pad:] # [ns - m + 1, bs, d]
padded_chunks = torch.nn.functional.pad(
attending_chunks,
(0, 0, 0, 0, 0, chunk_length-1),
'constant', 0) # [ns, bs, d]
padded_chunked_output = padded_chunks \
.reshape(l, chunk_length, bs, d) \
.permute(1, 2, 0, 3)
padded_chunked_output = padded_chunked_output.reshape(
chunk_length, bs * l, d).contiguous() # [m, bs * l, d]
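        # Causal alignment for chunked cross-attention: a token only attends
        # to neighbors retrieved for chunks that end before it (see the Retro
        # paper's chunked cross-attention). Dropping the first
        # pad = (ns - 1) % chunk_length positions and right-padding by
        # (chunk_length - 1) shifts the chunk grid accordingly. E.g.
        # (illustrative): ns=10, m=4 -> pad=1, attending_chunks has 9 tokens,
        # padded to 12 = l * m with l=3.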
# attention_output: [m, bs * l, d]
# attention_bias: [d]
attention_output, attention_bias = \
self.inter_attention(
padded_chunked_output, # Q: main model embedding
None,
encoder_output=retriever_output) # KV: retriever output embedding
# Residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(attention_output),
torch.zeros_like(attention_output),
self.hidden_dropout)
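        # Note: a zero tensor stood in for the residual above; the true
        # residual is added only after the output is un-chunked and
        # re-aligned below.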
layernorm_input = layernorm_input \
.reshape(chunk_length, bs, l, d) \
.permute(2, 0, 1, 3) # [l, m, bs, d]
layernorm_input = layernorm_input.reshape(chunk_length * l, bs, d)
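        # Left-pad by 'pad' and truncate to ns to undo the causal shift
        # applied before cross-attention, restoring the original [ns, bs, d]
        # token alignment.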
layernorm_input = torch.nn.functional.pad(
layernorm_input,
(0, 0, 0, 0, pad, 0),
'constant', 0)[:ns] # [ns, b, d]
layernorm_input = layernorm_input + residual
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output, retriever_output
class ParallelRetroTransformerLayer(MegatronModule):
"""A single transformer layer for Retro Decoder with cross attention.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args()
super(ParallelRetroTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
        # hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
            # JIT scripting for a nn.Module (with dropout) does not
            # trigger the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
args = get_args()
retro_args = get_retro_args()
chunk_length = retro_args.retro_gpt_chunk_length
ns, bs, d = layernorm_output.shape
l = int(np.ceil(ns / chunk_length))
pad = (ns - 1) % chunk_length
attending_chunks = layernorm_output[pad:]
padded_chunks = torch.nn.functional.pad(
attending_chunks,
(0, 0, 0, 0, 0, chunk_length - 1),
'constant', 0)
padded_chunked_output = padded_chunks \
.reshape(l, chunk_length, bs, d) \
.permute(1, 2, 0, 3)
padded_chunked_output = padded_chunked_output.reshape(
chunk_length, bs * l, d).contiguous()
# Encoder output.
attention_output, attention_bias = \
self.inter_attention(padded_chunked_output,
None,
encoder_output=retriever_output)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(attention_output),
torch.zeros_like(attention_output),
self.hidden_dropout)
layernorm_input = layernorm_input \
.reshape(chunk_length, bs, l, d) \
.permute(2, 0, 1, 3) # [l, m, bs, d]
layernorm_input = layernorm_input.reshape(chunk_length * l, bs, d)
layernorm_input = torch.nn.functional.pad(
layernorm_input,
(0, 0, 0, 0, pad, 0),
'constant', 0)[:ns]
layernorm_input = layernorm_input + residual
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output
class ParallelRetroEncoderTransformerCALayer(MegatronModule):
"""A single transformer layer for Retro Encoder with cross attention.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args()
super(ParallelRetroEncoderTransformerCALayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.self_attention.attention_dropout = \
torch.nn.Dropout(args.retro_encoder_attention_dropout)
self.hidden_dropout = args.retro_encoder_hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
            # JIT scripting for a nn.Module (with dropout) does not
            # trigger the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# Neighbors.
args = get_args()
retro_args = get_retro_args()
retrieved_length = retro_args.retro_gpt_retrieved_length
num_neighbors = args.retro_num_neighbors
ns, bs, d = layernorm_output.shape # [r, bs * l * k, d]
chunked_outputs = layernorm_output.reshape(retrieved_length, -1,
num_neighbors, d)
chunked_outputs_before_layer_norm = \
layernorm_input.reshape(retrieved_length, -1,
num_neighbors, d) # [r, bs * l, k, d]
layernorm_inputs = []
layernorm_outputs = []
for k in range(num_neighbors):
chunked_output = chunked_outputs[:,:,k].contiguous()
attention_output, attention_bias = \
self.inter_attention(
chunked_output, # Q (neighbor embedding)
None,
encoder_output=retriever_output) # K, V (hidden act)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = chunked_output
else:
residual = chunked_outputs_before_layer_norm[:,:,k]
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
layernorm_inputs.append(layernorm_input)
# Layer norm post the decoder attention
layernorm_output = \
self.post_inter_attention_layernorm(layernorm_input)
layernorm_outputs.append(layernorm_output)
# layernorm_input : [r, k * bs * l, d]
# layernorm_output : [r, k * bs * l, d]
layernorm_input = \
torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d)
layernorm_output = \
torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d)
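        # The k per-neighbor tensors of shape [r, bs * l, d] are stacked to
        # [r, k, bs * l, d] and flattened back to the layer's input shape.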
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
        # hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
            # JIT scripting for a nn.Module (with dropout) does not
            # trigger the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output
class ParallelRetroEncoder(MegatronModule):
""" Retro Transformer class for encoder ."""
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True, post_process=True,
drop_path_rate=0.0):
super(ParallelRetroEncoder, self).__init__()
args = get_args()
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
        # Store activation checkpointing flags.
self.recompute_granularity = args.recompute_granularity
self.recompute_method = args.recompute_method
self.recompute_num_layers = args.recompute_num_layers
self.distribute_saved_activations = \
args.distribute_saved_activations and not args.sequence_parallel
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = args.retro_encoder_layers
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
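        # Per-layer drop-path (stochastic depth) rates, increasing linearly
        # from 0 to drop_path_rate over the layers.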
if args.retro_add_retriever:
self.P = [1]
# Transformer layers.
assert args.retro_add_retriever
def build_layer(layer_number):
if layer_number in self.P:
return ParallelRetroEncoderTransformerCALayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
else:
layer = ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
layer.self_attention.attention_dropout = \
torch.nn.Dropout(args.retro_encoder_attention_dropout)
layer.hidden_dropout = args.retro_encoder_hidden_dropout
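                # Plain (non-cross-attention) encoder layers reuse the
                # standard ParallelTransformerLayer, but with the Retro
                # encoder's dropout rates.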
return layer
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in
# the stage, divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an
# assignment of layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an
# assignment of layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
if args.model_type == ModelType.encoder_and_decoder and \
parallel_state.get_pipeline_model_parallel_world_size() > 1:
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
# args.standalone_embedding_stage == True), virtual pipeline ranks
# on pipeline rank 0 will have zero transformer layers assigned to
# them. This results in the model's input and output tensors to be
# the same, which will cause failure for certain output tensor
# optimizations (e.g., pipeline output deallocation). To remedy
# this, we assign a 'no-op' layer on these ranks, which will
# disconnect the input tensor from the output tensor.
self.num_layers = 1
self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
else:
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process:
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # This method trades extra re-computation for lower memory usage.
            l = 0
            while l < self.num_layers:
                hidden_states = tensor_parallel.checkpoint(
                    custom(l, l + self.recompute_num_layers),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.recompute_num_layers
        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # This method fully uses device memory by avoiding redundant
            # re-computation.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + 1),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask,
                        encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask,
                        encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation recompute method.")
return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'
if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
        hidden_states = core_utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )
# Transpose encoder output.
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
args = get_args()
        assert not args.sequence_parallel, \
            "sequence parallelism is not supported here (needs an RNG-fork context)."
# Forward pass.
if self.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
if index + 1 in self.P:
hidden_states = layer(
hidden_states,
attention_mask,
retriever_output=retriever_output,
retriever_attn_mask=retriever_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
else:
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
# Final layer norm.
if self.post_process:
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
return output
class ParallelRetroTransformer(MegatronModule):
"""Standard GPT Transformer class."""
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True, post_process=True,
drop_path_rate=0.0, retriever=None):
super(ParallelRetroTransformer, self).__init__()
args = get_args()
        assert pre_process and post_process, \
            "pipeline parallelism is unsupported."
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
        # Store activation checkpointing flags.
self.recompute_granularity = args.recompute_granularity
self.recompute_method = args.recompute_method
self.recompute_num_layers = args.recompute_num_layers
self.distribute_saved_activations = \
args.distribute_saved_activations and not args.sequence_parallel
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = _get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
if args.retro_add_retriever:
if args.num_layers == 12:
self.P = [6, 9, 12]
elif args.num_layers == 24:
self.P = np.arange(9, 25, 3).tolist()
elif args.num_layers == 40:
self.P = np.arange(9, 41, 3).tolist()
self.P.append(40)
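            # self.P lists the decoder layers that perform chunked
            # cross-attention, e.g. num_layers=24 -> P = [9, 12, 15, 18, 21,
            # 24]: every third layer, similar to the placement used in the
            # Retro paper.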
self.retriever = retriever
# Transformer layers.
assert args.retro_add_retriever
def build_layer(layer_number):
if layer_number == min(self.P):
return ParallelRetroTransformerEncoderLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1],
retriever=retriever
)
elif layer_number in self.P:
return ParallelRetroTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
else:
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
if args.model_type == ModelType.encoder_and_decoder and \
parallel_state.get_pipeline_model_parallel_world_size() > 1:
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
# args.standalone_embedding_stage == True), virtual pipeline ranks
# on pipeline rank 0 will have zero transformer layers assigned to
# them. This results in the model's input and output tensors to be
# the same, which will cause failure for certain output tensor
# optimizations (e.g., pipeline output deallocation). To remedy
# this, we assign a 'no-op' layer on these ranks, which will
# disconnect the input tensor from the output tensor.
self.num_layers = 1
self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
else:
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process:
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # This method trades extra re-computation for lower memory usage.
            l = 0
            while l < self.num_layers:
                hidden_states = tensor_parallel.checkpoint(
                    custom(l, l + self.recompute_num_layers),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.recompute_num_layers
        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # This method fully uses device memory by avoiding redundant
            # re-computation.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + 1),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation recompute method.")
return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask,
retriever_output=None, retriever_attn_mask=None,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# Checks.
if inference_params:
assert self.recompute_granularity is None, \
'inference does not work with activation checkpointing'
args = get_args()
        # Transpose retriever output to match the hidden_states data layout.
retriever_output = retriever_output.transpose(0, 1).contiguous()
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = core_utils.make_viewless_tensor(
hidden_states,
requires_grad=True,
keep_graph=True,
)
# Transpose encoder output.
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
# Forward pass.
        assert not args.sequence_parallel, \
            "sequence parallelism is not supported here (needs an RNG-fork context)."
if self.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
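                # The first retro layer (min(self.P)) queries the retriever
                # and returns the neighbor encodings E; later retro layers
                # reuse E, and all remaining layers are plain GPT layers.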
if args.retro_add_retriever and index + 1 == min(self.P):
hidden_states, E = layer(
hidden_states,
attention_mask,
retriever_output=retriever_output,
retriever_attn_mask=retriever_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
elif args.retro_add_retriever and index + 1 in self.P:
hidden_states = layer(
hidden_states,
attention_mask,
retriever_output=E,
retriever_attn_mask=retriever_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
else:
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
# Final layer norm.
output = self.final_layernorm(hidden_states)
return output
...@@ -270,7 +270,6 @@ class CoreAttention(MegatronModule):
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        if not self.sequence_parallel:
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
...@@ -477,6 +476,7 @@ class ParallelAttention(MegatronModule):
        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
......
...@@ -244,6 +244,7 @@ def forward_backward_no_pipelining(forward_step_func,
            output_tensor = forward_step(forward_step_func, data_iterator,
                                         model, input_tensor, forward_data_store,
                                         timers, collect_non_loss_data)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad, timers)
......