"src/vscode:/vscode.git/clone" did not exist on "76c00c7236a4c8261947b5af5acdb086f5614576"
Commit aed2f75e authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main' into github-main

parents 8aa4619f f32a638d
......@@ -9,6 +9,16 @@ from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
......@@ -39,7 +49,7 @@ class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
seed, use_titles=True, use_one_sent_docs=False):
seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
......@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
query_mask = make_attention_mask(query_tokens, query_tokens)
context_mask = make_attention_mask(context_tokens, context_tokens)
block_data = sample_data.as_array()
sample = {
'query_tokens': query_tokens,
'query_mask': query_mask,
'query_pad_mask': query_pad_mask,
'block_tokens': block_tokens,
'block_pad_mask': block_pad_mask,
'context_tokens': context_tokens,
'context_mask': context_mask,
'context_pad_mask': context_pad_mask,
'block_data': block_data,
}
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wikipedia dataset from DPR code for ORQA."""
from abc import ABC
import csv
import numpy as np
import random
import torch
from torch.utils.data import Dataset
from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_open_retrieval_wiki_dataset():
args = get_args()
tokenizer = get_tokenizer()
dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
'evidence',
args.evidence_data_path,
tokenizer,
args.retriever_seq_length)
return dataset
def get_open_retrieval_batch(data_iterator):
# Items and their type.
keys = ['row_id', 'context', 'context_mask', 'context_types',
'context_pad_mask']
datatype = torch.int64
# Broadcast data.
data = None if data_iterator is None else next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
row_id = data_b['row_id'].long()
context = data_b['context'].long()
# TODO: make the context mask a binary one
context_mask = (data_b['context_mask'] < 0.5)
context_types = data_b['context_types'].long()
context_pad_mask = data_b['context_pad_mask'].long()
return row_id, context, context_mask, context_types, context_pad_mask
def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
"""Build token types and paddings, trim if needed, and pad if needed."""
title_ids = tokenizer.tokenize(row['title'])
context_ids = tokenizer.tokenize(row['text'])
# Appending the title of the context at front
extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids
context_ids, context_types, context_pad_mask = \
build_tokens_types_paddings_from_ids(extended_context_ids,
max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
return context_ids, context_types, context_pad_mask
# noinspection DuplicatedCode
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
cls_id, sep_id, pad_id):
"""Build token types and paddings, trim if needed, and pad if needed."""
enc_ids = []
tokentypes_enc = []
# [CLS].
enc_ids.append(cls_id)
tokentypes_enc.append(0)
# A.
len_src = len(text_ids)
enc_ids.extend(text_ids)
tokentypes_enc.extend([0] * len_src)
# Cap the size.
if len(enc_ids) > max_seq_length - 1:
enc_ids = enc_ids[0: max_seq_length - 1]
tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
# [SEP].
enc_ids.append(sep_id)
tokentypes_enc.append(0)
num_tokens_enc = len(enc_ids)
# Padding.
padding_length = max_seq_length - len(enc_ids)
if padding_length > 0:
enc_ids.extend([pad_id] * padding_length)
tokentypes_enc.extend([pad_id] * padding_length)
pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
pad_mask = np.array(pad_mask, dtype=np.int64)
return enc_ids, tokentypes_enc, pad_mask
def build_sample(row_id, context_ids, context_types, context_pad_mask):
"""Convert to numpy and return a sample consumed by the batch producer."""
context_ids = np.array(context_ids, dtype=np.int64)
context_types = np.array(context_types, dtype=np.int64)
context_mask = make_attention_mask(context_ids, context_ids)
sample = ({
'row_id': row_id,
'context': context_ids,
'context_mask': context_mask,
'context_types': context_types,
'context_pad_mask': context_pad_mask
})
return sample
class OpenRetrievalEvidenceDataset(ABC, Dataset):
"""Open Retrieval Evidence dataset class."""
def __init__(self, task_name, dataset_name, datapath, tokenizer,
max_seq_length):
# Store inputs.
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
# Process the files.
print_rank_0(datapath)
self.samples, self.id2text = self.process_samples_from_single_path(
datapath)
args = get_args()
if args.sample_rate < 1: # subsample
k = int(len(self.samples) * args.sample_rate)
self.samples = random.sample(self.samples, k)
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
row = self.samples[idx]
context_ids, context_types, context_pad_mask = \
build_tokens_types_paddings_from_text(row, self.tokenizer,
self.max_seq_length)
sample = build_sample(row['doc_id'],
context_ids,
context_types,
context_pad_mask)
return sample
@staticmethod
def process_samples_from_single_path(filename):
print_rank_0(' > Processing {} ...'.format(filename))
total = 0
rows = []
id2text = {}
with open(filename) as tsvfile:
reader = csv.reader(tsvfile, delimiter='\t')
next(reader, None) # skip the headers
for row in reader:
# file format: doc_id, doc_text, title
doc_id = int(row[0])
text = row[1]
title = row[2]
rows.append({'doc_id': doc_id,
'text': text,
'title': title})
assert doc_id not in id2text
id2text[doc_id] = (text, title)
total += 1
if total % 100000 == 0:
print_rank_0(' > processed {} rows so far ...'.format(
total))
print_rank_0(' >> processed {} samples.'.format(len(rows)))
return rows, id2text
......@@ -14,34 +14,36 @@ def detach(tensor):
return tensor.detach().cpu().numpy()
class BlockData(object):
"""Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
def __init__(self, block_data_path=None, load_from_path=True, rank=None):
class OpenRetreivalDataStore(object):
"""
Serializable data structure for holding data for blocks --
embeddings and necessary metadata for Retriever
"""
def __init__(self, embedding_path=None, load_from_path=True, rank=None):
self.embed_data = dict()
self.meta_data = dict()
if block_data_path is None:
if embedding_path is None:
args = get_args()
block_data_path = args.block_data_path
embedding_path = args.embedding_path
rank = args.rank
self.block_data_path = block_data_path
self.embedding_path = embedding_path
self.rank = rank
if load_from_path:
self.load_from_file()
block_data_name = os.path.splitext(self.block_data_path)[0]
block_data_name = os.path.splitext(self.embedding_path)[0]
self.temp_dir_name = block_data_name + '_tmp'
def state(self):
return {
'embed_data': self.embed_data,
'meta_data': self.meta_data,
}
def clear(self):
"""Clear the embedding data structures to save memory.
The metadata ends up getting used, and is also much smaller in dimensionality
so it isn't really worth clearing.
"""
Clear the embedding data structures to save memory.
The metadata ends up getting used, and is also much smaller in
dimensionality so it isn't really worth clearing.
"""
self.embed_data = dict()
......@@ -50,38 +52,39 @@ class BlockData(object):
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Unpickling BlockData", flush=True)
state_dict = pickle.load(open(self.block_data_path, 'rb'))
state_dict = pickle.load(open(self.embedding_path, 'rb'))
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Finished unpickling BlockData\n", flush=True)
self.embed_data = state_dict['embed_data']
self.meta_data = state_dict['meta_data']
def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
"""Add data for set of blocks
:param block_indices: 1D array of unique int ids for the blocks
def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
"""
Add data for set of blocks
:param row_id: 1D array of unique int ids for the blocks
:param block_embeds: 2D array of embeddings of the blocks
:param block_metas: 2D array of metadata for the blocks.
In the case of REALM this will be [start_idx, end_idx, doc_idx]
In the case of retriever this will be [start_idx, end_idx, doc_idx]
"""
for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
for idx, embed in zip(row_id, block_embeds):
if not allow_overwrite and idx in self.embed_data:
raise ValueError("Unexpectedly tried to overwrite block data")
self.embed_data[idx] = np.float16(embed)
self.meta_data[idx] = meta
def save_shard(self):
"""Save the block data that was created this in this process"""
"""
Save the block data that was created this in this process
"""
if not os.path.isdir(self.temp_dir_name):
os.makedirs(self.temp_dir_name, exist_ok=True)
# save the data for each shard
with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
pickle.dump(self.state(), data_file)
with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
as writer:
pickle.dump(self.state(), writer)
def merge_shards_and_save(self):
"""Combine all the shards made using self.save_shard()"""
#Combine all the shards made using save_shard
shard_names = os.listdir(self.temp_dir_name)
seen_own_shard = False
......@@ -96,15 +99,15 @@ class BlockData(object):
old_size = len(self.embed_data)
shard_size = len(data['embed_data'])
# add the shard's data and check to make sure there is no overlap
# add the shard's data and check to make sure there
# is no overlap
self.embed_data.update(data['embed_data'])
self.meta_data.update(data['meta_data'])
assert len(self.embed_data) == old_size + shard_size
assert seen_own_shard
# save the consolidated shards and remove temporary directory
with open(self.block_data_path, 'wb') as final_file:
with open(self.embedding_path, 'wb') as final_file:
pickle.dump(self.state(), final_file)
shutil.rmtree(self.temp_dir_name, ignore_errors=True)
......@@ -113,18 +116,22 @@ class BlockData(object):
class FaissMIPSIndex(object):
"""Wrapper object for a BlockData which similarity search via FAISS under the hood"""
def __init__(self, embed_size, block_data=None, use_gpu=False):
"""
Wrapper object for a BlockData which similarity search via FAISS under the hood
"""
def __init__(self, embed_size, embed_data=None, use_gpu=False):
self.embed_size = embed_size
self.block_data = block_data
self.embed_data = embed_data
self.use_gpu = use_gpu
self.id_map = dict()
self.block_mips_index = None
self._set_block_index()
self.mips_index = None
self._set_mips_index()
def _set_block_index(self):
"""Create a Faiss Flat index with inner product as the metric to search against"""
def _set_mips_index(self):
"""
Create a Faiss Flat index with inner product as the metric
to search against
"""
try:
import faiss
except ImportError:
......@@ -132,85 +139,86 @@ class FaissMIPSIndex(object):
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Building index", flush=True)
self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
cpu_index = faiss.IndexFlatIP(self.embed_size)
if self.use_gpu:
# create resources and config for GpuIndex
res = faiss.StandardGpuResources()
config = faiss.GpuIndexFlatConfig()
config.device = torch.cuda.current_device()
config = faiss.GpuMultipleClonerOptions()
config.shard = True
config.useFloat16 = True
self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
self.mips_index = faiss.IndexIDMap(gpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
print(">> Initialized index on GPU", flush=True)
else:
# CPU index supports IDs so wrap with IDMap
self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
self.mips_index = faiss.IndexIDMap(cpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on CPU", flush=True)
# if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
if self.block_data is not None:
self.add_block_embed_data(self.block_data)
# if we were constructed with a BlockData, then automatically load it
# when the FAISS structure is built
if self.embed_data is not None:
self.add_embed_data(self.embed_data)
def reset_index(self):
"""Delete existing index and create anew"""
del self.block_mips_index
"""Delete existing index and create a new"""
del self.mips_index
# reset the block data so that _set_block_index will reload it as well
if self.block_data is not None:
block_data_path = self.block_data.block_data_path
del self.block_data
self.block_data = BlockData(block_data_path)
if self.embed_data is not None:
embed_data_path = self.embed_data.embedding_path
del self.embed_data
self.embed_data = OpenRetreivalDataStore(embed_data_path)
self._set_block_index()
self._set_mips_index()
def add_block_embed_data(self, all_block_data):
def update_index(self):
"""Delete existing index and create a new"""
del self.mips_index
# reset the block data so that _set_mips_index will reload it as well
if self.embed_data is not None:
self.embed_data.load_from_file()
self._set_mips_index()
def add_embed_data(self, all_embed_data):
"""Add the embedding of each block to the underlying FAISS index"""
# this assumes the embed_data is a dict : {int: np.array<float>}
block_indices, block_embeds = zip(*all_block_data.embed_data.items())
block_indices, block_embeds = zip(*all_embed_data.embed_data.items())
# the embeddings have to be entered in as float32 even though the math internally is done with float16.
block_embeds_arr = np.float32(np.array(block_embeds))
block_indices_arr = np.array(block_indices)
# faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
if self.use_gpu:
for i, idx in enumerate(block_indices):
self.id_map[i] = idx
# the embeddings have to be entered in as float32 even though the math
# internally is done with float16.
embeds_arr = np.float32(np.array(block_embeds))
indices_arr = np.array(block_indices)
# we no longer need the embedding data since it's in the index now
all_block_data.clear()
all_embed_data.clear()
if self.use_gpu:
self.block_mips_index.add(block_embeds_arr)
else:
self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
self.mips_index.add_with_ids(embeds_arr, indices_arr)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">>> Finished adding block data to index", flush=True)
def search_mips_index(self, query_embeds, top_k, reconstruct=True):
"""Get the top-k blocks by the index distance metric.
"""
Get the top-k blocks by the index distance metric.
:param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
if False: return [num_queries x k] array of distances, and another for indices
:param reconstruct: if True: return a [num_queries x k x embed_dim]
array of blocks
if False: return [num_queries x k] array of
distances, and another for indices
"""
query_embeds = np.float32(detach(query_embeds))
if reconstruct:
# get the vectors themselves
top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
top_k_block_embeds = self.mips_index.search_and_reconstruct(\
query_embeds, top_k)
return top_k_block_embeds
else:
# get distances and indices of closest vectors
distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
if self.use_gpu:
fresh_indices = np.zeros(block_indices.shape)
for i, j in itertools.product(block_indices.shape):
fresh_indices[i, j] = self.id_map[block_indices[i, j]]
block_indices = fresh_indices
distances, block_indices = self.mips_index.search(query_embeds, top_k)
return distances, block_indices
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from torchvision import datasets, transforms
from megatron.data.autoaugment import ImageNetPolicy
def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
# training dataset
train_data_path = os.path.join(data_path[0], "train")
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
process = [
transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(),
]
if color_jitter:
process += [
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
)
]
fp16_t = transforms.ConvertImageDtype(torch.half)
process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
transform_train = transforms.Compose(process)
train_data = datasets.ImageFolder(
root=train_data_path, transform=transform_train
)
# validation dataset
val_data_path = os.path.join(data_path[0], "val")
transform_val = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
fp16_t
]
)
val_data = datasets.ImageFolder(
root=val_data_path, transform=transform_val
)
return train_data, val_data
......@@ -13,114 +13,97 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import subprocess
import os
from torch.utils import cpp_extension
# Setting this param to a list has a problem of generating
# different compilation commands (with diferent order of architectures)
# and leading to recompilation of fused kernels.
# set it to empty string to avoid recompilation
# and assign arch flags explicity in extra_cuda_cflags below
# Setting this param to a list has a problem of generating different
# compilation commands (with diferent order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
# to avoid recompilation and assign arch flags explicity in
# extra_cuda_cflags below
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")
def load_scaled_upper_triang_masked_softmax_fusion_kernel():
def load(args):
# Check, if CUDA11 is installed for compute capability 8.0
# Check if cuda 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
_, bare_metal_major, _ = _get_cuda_bare_metal_version(
cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
create_build_dir(buildpath)
scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
name='scaled_upper_triang_masked_softmax_cuda',
_create_build_dir(buildpath)
# Helper function to build the kernels.
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
return cpp_extension.load(
name=name,
sources=sources,
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'--use_fast_math'] + extra_cuda_flags + cc_flag,
verbose=(args.rank == 0)
)
# ==============
# Fused softmax.
# ==============
if args.masked_softmax_fusion:
extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda']
# Upper triangular softmax.
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + cc_flag)
def load_scaled_masked_softmax_fusion_kernel():
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_upper_triang_masked_softmax_cuda",
sources, extra_cuda_flags)
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
# Masked softmax.
sources=[srcpath / 'scaled_masked_softmax.cpp',
srcpath / 'scaled_masked_softmax_cuda.cu']
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_masked_softmax_cuda", sources, extra_cuda_flags)
create_build_dir(buildpath)
# =================================
# Mixed precision fused layer norm.
# =================================
scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
name='scaled_masked_softmax_cuda',
sources=[srcpath / 'scaled_masked_softmax.cpp',
srcpath / 'scaled_masked_softmax_cuda.cu'],
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + cc_flag)
extra_cuda_flags = ['-maxrregcount=50']
sources=[srcpath / 'layer_norm_cuda.cpp',
srcpath / 'layer_norm_cuda_kernel.cu']
fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
"fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
def load_fused_mix_prec_layer_norm_kernel():
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
return raw_output, bare_metal_major, bare_metal_minor
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
create_build_dir(buildpath)
fused_mix_prec_layer_norm_cuda = cpp_extension.load(
name='fused_mix_prec_layer_norm_cuda',
sources=[srcpath / 'layer_norm_cuda.cpp',
srcpath / 'layer_norm_cuda_kernel.cu'],
build_directory=buildpath,
extra_cflags=['-O3'],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-maxrregcount=50',
'--use_fast_math'] + cc_flag)
def _create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")
......@@ -24,16 +24,12 @@
#include "compat.h"
namespace {
void compute_n1_n2(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
int& n1,
int& n2)
{
int& n2) {
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
......@@ -47,11 +43,7 @@ void compute_n1_n2(
}
void check_args(
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta
)
......@@ -62,11 +54,7 @@ void check_args(
void check_args(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
int& n1,
int& n2
)
......@@ -102,11 +90,7 @@ void check_args(
void check_args(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta,
int& n1,
......@@ -125,60 +109,42 @@ void cuda_layer_norm(
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
double epsilon) {
CHECK_INPUT(input);
int n1,n2;
check_args(input,normalized_shape,n1,n2);
at::Tensor output = at::empty_like(input);
at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,NULL,NULL,epsilon);
return {output, mean, invvar};
}
std::vector<at::Tensor> layer_norm_affine(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
at::Tensor output = at::empty_like(input, input.options().dtype(at::ScalarType::Half));
at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor output = at::empty_like(
input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean = at::empty(
{n1}, input.options().dtype(at::ScalarType::Float));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon);
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
......@@ -186,11 +152,7 @@ void cuda_layer_norm_gradient(
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
......@@ -199,62 +161,41 @@ void cuda_layer_norm_gradient(
at::Tensor* grad_beta
);
at::Tensor layer_norm_gradient(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
int n1,n2;
check_args(input,normalized_shape,n1,n2);
at::Tensor grad_input = at::empty_like(input);
cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
normalized_shape,NULL,NULL,epsilon,
&grad_input,NULL,NULL);
return grad_input;
}
std::vector<at::Tensor> layer_norm_gradient_affine(
at::Tensor dout,
at::Tensor mean,
at::Tensor invvar,
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon,
&grad_input,&grad_gamma,&grad_beta);
cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon,
&grad_input, &grad_gamma, &grad_beta);
return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
m.def("forward_affine", &layer_norm_affine,
"LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine,
"LayerNorm backward (CUDA)");
}
......@@ -285,15 +285,6 @@ struct SharedMemory <float>
}
};
template <>
struct SharedMemory <double>
{
__device__ double *getPointer()
{
extern __shared__ double s_double[];
return s_double;
}
};
}
template<typename T, typename U, typename V> __global__
......@@ -656,6 +647,9 @@ void cuComputeGradInput(
}
}
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
V* output,
......@@ -671,7 +665,8 @@ void HostApplyLayerNorm(
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32,4,1);
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ?
......@@ -687,6 +682,7 @@ void HostApplyLayerNorm(
gamma,beta);
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
......@@ -704,21 +700,21 @@ void cuda_layer_norm(
double epsilon)
{
using namespace at;
DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
using output_t = at::Half;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
HostApplyLayerNorm(
output->DATA_PTR<output_t>(),
mean->DATA_PTR<accscalar_t>(),
invvar->DATA_PTR<accscalar_t>(),
input->DATA_PTR<scalar_t_0>(),
output->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input->DATA_PTR<scalar_t_in>(),
n1,n2,
epsilon,
gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
beta != NULL ? beta->DATA_PTR<output_t>() : NULL);
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
template<typename T, typename U, typename V>
void HostLayerNormGradient(
const V* dout,
......@@ -742,10 +738,12 @@ void HostLayerNormGradient(
const int part_size = 16;
const dim3 threads2(32,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
(threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
at::Tensor part_grad_gamma = at::empty(
{part_size,n2}, input->options().dtype(at::ScalarType::Float));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout,
......@@ -770,7 +768,8 @@ void HostLayerNormGradient(
}
// compute grad_input
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32,4,1);
int nshared =
......@@ -788,6 +787,7 @@ void HostLayerNormGradient(
grad_input);
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
......@@ -808,22 +808,22 @@ void cuda_layer_norm_gradient(
at::Tensor* grad_beta)
{
using namespace at;
DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
using accscalar_t = at::acc_type<scalar_t_0, true>;
using output_t = at::Half;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(),
"cuda_layer_norm_gradient_kernel",
HostLayerNormGradient(
dout->DATA_PTR<output_t>(),
mean->DATA_PTR<accscalar_t>(),
invvar->DATA_PTR<accscalar_t>(),
dout->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input,
n1,n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
gamma != NULL ? beta->DATA_PTR<output_t>() : NULL,
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
epsilon,
grad_input->DATA_PTR<scalar_t_0>(),
gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL);
grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
......@@ -37,8 +37,9 @@ torch::Tensor fwd(
torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
return fwd_cuda(input, mask, scale_factor);
......@@ -52,10 +53,12 @@ torch::Tensor bwd(
AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
......
......@@ -19,10 +19,10 @@
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
......@@ -37,33 +37,39 @@ torch::Tensor fwd_cuda(
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int seq_len = input.size(2);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
dispatch_scaled_masked_softmax_forward<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
seq_len,
seq_len,
batches,
attn_heads,
pad_batches);
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
);
return softmax_results;
}
......@@ -78,21 +84,25 @@ torch::Tensor bwd_cuda(
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int seq_len = output_grads.size(2);
TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
dispatch_scaled_masked_softmax_backward<half, half, float>(
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
batches,
attn_heads);
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads;
......
......@@ -33,8 +33,9 @@ torch::Tensor bwd_cuda(
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
......@@ -47,10 +48,12 @@ torch::Tensor bwd(
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
"Only HALF is supported");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
......@@ -61,7 +64,7 @@ torch::Tensor bwd(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
......
......@@ -21,11 +21,47 @@
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
......@@ -73,7 +109,7 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Implicit time (diagonal masking)
*/
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst,
......@@ -89,10 +125,11 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE;
int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
......@@ -103,22 +140,36 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * stride + local_idx;
dst += first_batch * stride + local_idx;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
elements[i][it] = (acc_t)src[i*element_count*stride+it*WARP_SIZE] * scale;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if ((element_index + element) < batch_element_count) {
elements[i][it+element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
elements[i][it] = -std::numeric_limits<acc_t>::infinity();
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
......@@ -140,26 +191,37 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
dst[i*element_count*stride+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
} else {
out[element] = 0;
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
} else if (element_index < element_count) {
dst[i*element_count*stride+it*WARP_SIZE] = 0;
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
} else {
break;
}
......@@ -183,6 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1;
......@@ -197,37 +260,41 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + local_idx;
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
output_reg[i][it] = output[i*element_count*stride+it*WARP_SIZE];
} else {
output_reg[i][it] = acc_t(0);
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
grad_reg[i][it] = (acc_t)grad[i*element_count*stride+it*WARP_SIZE] * output_reg[i][it];
} else {
grad_reg[i][it] = acc_t(0);
}
}
}
acc_t sum[WARP_BATCH];
......@@ -247,11 +314,16 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
gradInput[i*element_count*stride+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
}
}
}
......
......@@ -19,10 +19,10 @@
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
......@@ -46,15 +46,20 @@ torch::Tensor fwd_cuda(
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
......@@ -72,14 +77,18 @@ torch::Tensor bwd_cuda(
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half*>(output_grads_ptr),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches);
);
//backward pass is completely in-place
return output_grads;
......
......@@ -14,214 +14,78 @@
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <ATen/ATen.h>
#include "compat.h"
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = x[tid] + x[tid+i];
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = x[tid] + x[tid+32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op
(T *x,
T val,
int lanes=1,
bool share_result=false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
if(blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if(tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
__syncthreads();
}
T final;
if(tid < 32)
{
if(blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for(int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if(share_result)
{
if(tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
......@@ -83,7 +83,8 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_build_num_microbatches_calculator(args)
_ = _build_tokenizer(args)
if args.vocab_file:
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
_set_timers()
......@@ -131,12 +132,13 @@ def _set_tensorboard_writer(args):
'tensorboard writer')
if hasattr(args, 'tensorboard_dir') and \
args.tensorboard_dir and args.rank == (args.world_size -1):
args.tensorboard_dir and args.rank == (args.world_size - 1):
try:
from torch.utils.tensorboard import SummaryWriter
print('> setting tensorboard ...')
_GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
log_dir=args.tensorboard_dir)
log_dir=args.tensorboard_dir,
max_queue=args.tensorboard_queue_size)
except ModuleNotFoundError:
print('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
......
import sys
import torch
import torch.distributed as dist
from megatron import get_args
from megatron import mpu
from megatron.checkpointing import load_ict_checkpoint
from megatron.data.ict_dataset import get_ict_dataset
from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, BlockData
from megatron.data.realm_dataset_utils import get_ict_batch
from megatron.model.realm_model import general_ict_model_provider
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import get_model
class IndexBuilder(object):
"""Object for taking one pass over a dataset and creating a BlockData of its embeddings"""
"""
Object for taking one pass over a dataset and creating a BlockData of its
embeddings
"""
def __init__(self):
args = get_args()
self.model = None
self.dataloader = None
self.block_data = None
self.evidence_embedder_obj = None
self.biencoder_shared_query_context_model = \
args.biencoder_shared_query_context_model
# need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
# need to know whether we're using a REALM checkpoint (args.load)
# or ICT checkpoint
assert not (args.load and args.ict_load)
self.using_realm_chkpt = args.ict_load is None
#self.using_realm_chkpt = args.ict_load is None
self.log_interval = args.indexer_log_interval
self.batch_size = args.indexer_batch_size
......@@ -33,59 +40,88 @@ class IndexBuilder(object):
self.iteration = self.total_processed = 0
def load_attributes(self):
"""Load the necessary attributes: model, dataloader and empty BlockData"""
model = get_model(lambda: general_ict_model_provider(only_block_model=True))
self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
self.model.eval()
self.dataset = get_ict_dataset()
self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
self.block_data = BlockData(load_from_path=False)
"""
Load the necessary attributes: model, dataloader and empty BlockData
"""
only_context_model = True
if self.biencoder_shared_query_context_model:
only_context_model = False
model = get_model(lambda: biencoder_model_provider(only_context_model \
= only_context_model, biencoder_shared_query_context_model = \
self.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model,
only_context_model=only_context_model)
assert len(self.model) == 1
self.model[0].eval()
self.dataset = get_open_retrieval_wiki_dataset()
self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
self.batch_size))
self.evidence_embedder_obj = OpenRetreivalDataStore( \
load_from_path=False)
def track_and_report_progress(self, batch_size):
"""Utility function for tracking progress"""
"""
Utility function for tracking progress
"""
self.iteration += 1
self.total_processed += batch_size * self.num_total_builders
if self.is_main_builder and self.iteration % self.log_interval == 0:
print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)
print('Batch {:10d} | Total {:10d}'.format(self.iteration,
self.total_processed), flush=True)
def build_and_save_index(self):
"""Goes through one epoch of the dataloader and adds all data to this instance's BlockData.
"""
Goes through one epoch of the dataloader and adds all data to this
instance's BlockData.
The copy of BlockData is saved as a shard, which when run in a distributed setting will be
consolidated by the rank 0 process and saved as a final pickled BlockData.
The copy of BlockData is saved as a shard, which when run in a
distributed setting will be consolidated by the rank 0 process
and saved as a final pickled BlockData.
"""
assert len(self.model) == 1
unwrapped_model = self.model[0]
while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module
while True:
try:
# batch also has query_tokens and query_pad_data
_, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
row_id, context_tokens, context_mask, context_types, \
context_pad_mask = get_open_retrieval_batch( \
self.dataloader)
except (StopIteration, IndexError):
break
unwrapped_model = self.model
while not hasattr(unwrapped_model, 'embed_block'):
unwrapped_model = unwrapped_model.module
# TODO: can we add with torch.no_grad() to reduce memory usage
# detach, separate fields and add to BlockData
block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
detached_data = detach(block_sample_data)
# block_sample_data is a 2D array [batch x 4]
# with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
block_indices = detached_data[:, 3]
block_metas = detached_data[:, :3]
self.block_data.add_block_data(block_indices, block_logits, block_metas)
self.track_and_report_progress(batch_size=block_tokens.shape[0])
# This process signals to finalize its shard and then synchronize with the other processes
self.block_data.save_shard()
assert context_mask.dtype == torch.bool
context_logits = unwrapped_model.embed_text(
unwrapped_model.context_model, context_tokens, context_mask,
context_types)
context_logits = detach(context_logits)
row_id = detach(row_id)
self.evidence_embedder_obj.add_block_data(row_id, context_logits)
self.track_and_report_progress(batch_size=len(row_id))
# This process signals to finalize its shard and then synchronize with
# the other processes
self.evidence_embedder_obj.save_shard()
torch.distributed.barrier()
del self.model
# rank 0 process builds the final copy
if self.is_main_builder:
self.block_data.merge_shards_and_save()
self.evidence_embedder_obj.merge_shards_and_save()
# make sure that every single piece of data was embedded
assert len(self.block_data.embed_data) == len(self.dataset)
self.block_data.clear()
assert len(self.evidence_embedder_obj.embed_data) == \
len(self.dataset)
self.evidence_embedder_obj.clear()
# complete building the final copy
torch.distributed.barrier()
......@@ -17,16 +17,20 @@
import random
import os
import time
import numpy as np
import torch
from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size
from megatron.mpu import (set_tensor_model_parallel_rank,
set_tensor_model_parallel_world_size)
def initialize_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False, allow_no_cuda=False):
......@@ -37,8 +41,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
what you are doing.
Returns a function to finalize distributed env initialization
(optionally, only when args.lazy_mpu_init == True)
"""
"""
if not allow_no_cuda:
# Make sure cuda is available.
assert torch.cuda.is_available(), 'Megatron requires CUDA.'
......@@ -66,7 +69,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
# delayed initialization of DDP-related stuff
# We only set basic DDP globals
set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
# and return function for external DDP manager to call when it has DDP initialized
# and return function for external DDP manager
# to call when it has DDP initialized
set_tensor_model_parallel_rank(args.rank)
return finish_mpu_init
else:
......@@ -79,19 +83,71 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
# Autoresume.
_init_autoresume()
# Compile dataset C++ code.
try:
from megatron.data import helpers
except:
if torch.distributed.get_rank() == 0:
from megatron.data.dataset_utils import compile_helper
compile_helper()
# Simple barrier
torch.distributed.barrier()
# Compile dependencies.
_compile_dependencies()
# No continuation function
return None
def _compile_dependencies():
args = get_args()
# =========================
# Compile dataset C++ code.
# =========================
# TODO: move this to ninja
if torch.distributed.get_rank() == 0:
start_time = time.time()
print('> compiling dataset index builder ...')
from megatron.data.dataset_utils import compile_helper
compile_helper()
print('>>> done with dataset index builder. Compilation time: {:.3f} '
'seconds'.format(time.time() - start_time), flush=True)
# ==================
# Load fused kernels
# ==================
# Custom kernel constraints check.
seq_len = args.seq_length
attn_batch_size = \
(args.num_attention_heads / args.tensor_model_parallel_size) * \
args.micro_batch_size
# Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \
seq_len % 4 == 0 and attn_batch_size % 4 == 0
# Print a warning.
if not ((args.fp16 or args.bf16) and
custom_kernel_constraint and
args.masked_softmax_fusion):
if args.rank == 0:
print('WARNING: constraints for invoking optimized'
' fused softmax kernel are not met. We default'
' back to unfused kernel invocations.', flush=True)
# Always build on rank zero first.
if torch.distributed.get_rank() == 0:
start_time = time.time()
print('> compiling and loading fused kernels ...', flush=True)
fused_kernels.load(args)
torch.distributed.barrier()
else:
torch.distributed.barrier()
fused_kernels.load(args)
# Simple barrier to make sure all ranks have passed the
# compilation phase successfully before moving on to the
# rest of the program. We think this might ensure that
# the lock is released.
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>>> done with compiling and loading fused kernels. '
'Compilation time: {:.3f} seconds'.format(
time.time() - start_time), flush=True)
def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
......@@ -136,7 +192,8 @@ def _initialize_distributed():
print('model parallel is already initialized')
else:
mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size)
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size)
def _init_autoresume():
......
......@@ -13,34 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
_LAYER_NORM = None
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
def import_layernorm(fp32_residual_connection):
global _LAYER_NORM
if not _LAYER_NORM:
if fp32_residual_connection:
from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
else:
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
_LAYER_NORM = LayerNorm
return _LAYER_NORM
from .distributed import *
from .bert_model import (BertModel,
BertModelFirstStage,
BertModelIntermediateStage,
BertModelLastStage)
from .realm_model import ICTBertModel
from .gpt_model import (GPTModel,
GPTModelFirstStage,
GPTModelIntermediateStage,
GPTModelLastStage)
from .distributed import DistributedDataParallel
from .bert_model import BertModel
from .gpt_model import GPTModel
from .language_model import get_language_model
from .module import FP16Module
from .realm_model import ICTBertModel
from .module import Float16Module
......@@ -19,19 +19,16 @@ import torch
from megatron import get_args
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model import import_layernorm
from megatron.model import LayerNorm
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
......@@ -77,13 +74,10 @@ class BertLMHead(MegatronModule):
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.tensor_model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = 1
mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
LayerNorm = import_layernorm(args.fp32_residual_connection)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
......@@ -127,31 +121,39 @@ def post_language_model_processing(lm_output, pooled_output,
return lm_loss, binary_logits
class BertModelBase(MegatronModule):
class BertModel(MegatronModule):
"""Bert Language model."""
def __init__(self, num_tokentypes=2, add_binary_head=True,
parallel_output=True):
super(BertModelBase, self).__init__()
def __init__(self,
num_tokentypes=2,
add_binary_head=True,
parallel_output=True,
pre_process=True,
post_process=True):
super(BertModel, self).__init__()
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.add_binary_head = add_binary_head
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
self.initialize_word_embeddings(init_method_normal)
if mpu.is_pipeline_last_stage():
if self.post_process:
self.lm_head = BertLMHead(
self.word_embeddings_weight().size(0),
args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
......@@ -162,26 +164,30 @@ class BertModelBase(MegatronModule):
init_method)
self._binary_head_key = 'binary_head'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, bert_model_input, attention_mask,
tokentype_ids=None, lm_labels=None):
extended_attention_mask = bert_extended_attention_mask(attention_mask)
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
kwargs = {}
if mpu.is_pipeline_first_stage():
input_ids = bert_model_input
position_ids = bert_position_ids(input_ids)
args = [input_ids, position_ids, extended_attention_mask]
kwargs['tokentype_ids'] = tokentype_ids
else:
args = [bert_model_input, extended_attention_mask]
lm_output = self.language_model(*args, **kwargs)
if mpu.is_pipeline_last_stage() and self.add_binary_head:
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids
)
if self.post_process and self.add_binary_head:
lm_output, pooled_output = lm_output
else:
pooled_output = None
if mpu.is_pipeline_last_stage():
if self.post_process:
return post_language_model_processing(lm_output, pooled_output,
self.lm_head, self.binary_head,
lm_labels,
......@@ -200,15 +206,15 @@ class BertModelBase(MegatronModule):
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage():
if self.post_process:
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage() and self.add_binary_head:
if self.post_process and self.add_binary_head:
state_dict_[self._binary_head_key] \
= self.binary_head.state_dict(destination, prefix, keep_vars)
# Save word_embeddings.
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
if self.post_process and not self.pre_process:
state_dict_[self._word_embeddings_for_head_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
return state_dict_
......@@ -218,74 +224,13 @@ class BertModelBase(MegatronModule):
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if mpu.is_pipeline_last_stage():
if self.post_process:
self.lm_head.load_state_dict(
state_dict[self._lm_head_key], strict=strict)
if mpu.is_pipeline_last_stage() and self.add_binary_head:
if self.post_process and self.add_binary_head:
self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict)
# Load word_embeddings.
if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
if self.post_process and not self.pre_process:
self.word_embeddings.load_state_dict(
state_dict[self._word_embeddings_for_head_key], strict=strict)
class BertModel(BertModelBase):
def __init__(self, num_tokentypes=2, add_binary_head=True,
parallel_output=True):
super(BertModel, self).__init__(
num_tokentypes=num_tokentypes,
add_binary_head=add_binary_head,
parallel_output=parallel_output)
def forward(self, input_ids, attention_mask,
tokentype_ids=None, lm_labels=None):
return super(BertModel, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids,
lm_labels=lm_labels)
class BertModelFirstStage(BertModelBase):
def __init__(self, num_tokentypes=2):
super(BertModelFirstStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
return super(BertModelFirstStage, self).forward(
input_ids,
attention_mask,
tokentype_ids=tokentype_ids)
class BertModelIntermediateStage(BertModelBase):
def __init__(self, num_tokentypes=2):
super(BertModelIntermediateStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask):
return super(BertModelIntermediateStage, self).forward(
hidden_state,
attention_mask)
class BertModelLastStage(BertModelBase):
def __init__(self, num_tokentypes=2, add_binary_head=True,
parallel_output=True):
super(BertModelLastStage, self).__init__(
num_tokentypes=num_tokentypes,
add_binary_head=add_binary_head,
parallel_output=parallel_output)
def forward(self, hidden_state, attention_mask,
lm_labels=None):
return super(BertModelLastStage, self).forward(
hidden_state,
attention_mask,
lm_labels=lm_labels)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment