Commit ee7b19e7 authored by Mostofa Patwary's avatar Mostofa Patwary

Merge branch 'main' into main_dedup

parents d413bd5f f2d64c00
#!/bin/bash
# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)
RANK=0
WORLD_SIZE=1
# Wikipedia data can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path to store embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
python tools/create_doc_index.py \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--micro-batch-size 128 \
--checkpoint-activations \
--seq-length 512 \
--retriever-seq-length 256 \
--max-position-embeddings 512 \
--load ${CHECKPOINT_PATH} \
--evidence-data-path ${EVIDENCE_DATA_DIR} \
--embedding-path ${EMBEDDING_PATH} \
--indexer-log-interval 1000 \
--indexer-batch-size 128 \
--vocab-file bert-vocab.txt \
--num-workers 2 \
--fp16
#!/bin/bash
# Evaluate Natural Questions (NQ) test data given Wikipedia embeddings and a
# pretrained ICT model
# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path of the embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
QA_FILE=<Specify path of the Natural Questions test dataset>
python tasks/main.py \
--task ICT-ZEROSHOT-NQ \
--tokenizer-type BertWordPieceLowerCase \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--micro-batch-size 128 \
--checkpoint-activations \
--seq-length 512 \
--max-position-embeddings 512 \
--load ${CHECKPOINT_PATH} \
--evidence-data-path ${EVIDENCE_DATA_DIR} \
--embedding-path ${EMBEDDING_PATH} \
--retriever-seq-length 256 \
--vocab-file bert-vocab.txt \
--qa-data-test ${QA_FILE} \
--num-workers 2 \
--faiss-use-gpu \
--retriever-report-topk-accuracies 1 5 20 100 \
--fp16
#!/bin/bash
# Runs the "217M" parameter biencoder model for the ICT retriever
RANK=0
WORLD_SIZE=1
PRETRAINED_BERT_PATH=<Specify path of pretrained BERT model>
TEXT_DATA_PATH=<Specify path and file prefix of the text data>
TITLE_DATA_PATH=<Specify path and file prefix of the titles>
CHECKPOINT_PATH=<Specify path>
python pretrain_ict.py \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--micro-batch-size 32 \
--seq-length 256 \
--max-position-embeddings 512 \
--train-iters 100000 \
--vocab-file bert-vocab.txt \
--tokenizer-type BertWordPieceLowerCase \
--DDP-impl torch \
--bert-load ${PRETRAINED_BERT_PATH} \
--log-interval 100 \
--eval-interval 1000 \
--eval-iters 10 \
--retriever-report-topk-accuracies 1 5 10 20 100 \
--retriever-score-scaling \
--load $CHECKPOINT_PATH \
--save $CHECKPOINT_PATH \
--data-path ${TEXT_DATA_PATH} \
--titles-data-path ${TITLE_DATA_PATH} \
--lr 0.0001 \
--lr-decay-style linear \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction 0.01 \
--save-interval 4000 \
--exit-interval 8000 \
--query-in-block-prob 0.1 \
--fp16
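The "217M" figure is consistent with a biencoder holding two BERT-base towers (12 layers, hidden size 768), one for queries and one for contexts. A rough, back-of-the-envelope parameter count, included here only as an illustrative sanity check and not taken from the script:
# Back-of-the-envelope count (illustrative only): two BERT-base towers,
# biases and layernorm parameters omitted for simplicity.
vocab, hidden, layers, max_pos = 30522, 768, 12, 512
embeddings = (vocab + max_pos + 2) * hidden                   # token + position + type embeddings
per_layer = 4 * hidden * hidden + 2 * hidden * (4 * hidden)   # attention + MLP weight matrices
one_tower = embeddings + layers * per_layer
print(2 * one_tower / 1e6)                                    # ~217.5 (million parameters)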
...@@ -19,7 +19,6 @@ import argparse ...@@ -19,7 +19,6 @@ import argparse
import os import os
import torch import torch
from megatron import fused_kernels
def parse_args(extra_args_provider=None, defaults={}, def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False): ignore_unknown_args=False):
...@@ -39,7 +38,7 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -39,7 +38,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_validation_args(parser) parser = _add_validation_args(parser)
parser = _add_data_args(parser) parser = _add_data_args(parser)
parser = _add_autoresume_args(parser) parser = _add_autoresume_args(parser)
parser = _add_realm_args(parser) parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser) parser = _add_vit_args(parser)
parser = _add_logging_args(parser) parser = _add_logging_args(parser)
...@@ -70,7 +69,7 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -70,7 +69,7 @@ def parse_args(extra_args_provider=None, defaults={},
model_parallel_size = args.pipeline_model_parallel_size * \ model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size args.tensor_model_parallel_size
assert args.world_size % model_parallel_size == 0, 'world size is not'\ assert args.world_size % model_parallel_size == 0, 'world size is not'\
' divisible by tensor parallel size ({}) times pipeline paralle ' \ ' divisible by tensor parallel size ({}) times pipeline parallel ' \
'size ({})'.format(args.world_size, args.tensor_model_parallel_size, 'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
args.pipeline_model_parallel_size) args.pipeline_model_parallel_size)
args.data_parallel_size = args.world_size // model_parallel_size args.data_parallel_size = args.world_size // model_parallel_size
...@@ -116,15 +115,38 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -116,15 +115,38 @@ def parse_args(extra_args_provider=None, defaults={},
print('setting global batch size to {}'.format( print('setting global batch size to {}'.format(
args.global_batch_size), flush=True) args.global_batch_size), flush=True)
assert args.global_batch_size > 0 assert args.global_batch_size > 0
if args.num_layers_per_virtual_pipeline_stage is not None:
assert args.pipeline_model_parallel_size > 2, \
'pipeline-model-parallel size should be greater than 2 with ' \
'interleaved schedule'
assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
'number of layers is not divisible by number of layers per virtual ' \
'pipeline stage'
args.virtual_pipeline_model_parallel_size = \
(args.num_layers // args.pipeline_model_parallel_size) // \
args.num_layers_per_virtual_pipeline_stage
else:
args.virtual_pipeline_model_parallel_size = None
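As a worked illustration of the setting above (hypothetical values, not taken from the diff): with 24 layers split over 4 pipeline stages and 3 layers per virtual stage, each pipeline stage holds two virtual chunks.
# Worked example of the interleaved-schedule arithmetic above (hypothetical values).
num_layers = 24
pipeline_model_parallel_size = 4
num_layers_per_virtual_pipeline_stage = 3
virtual_pipeline_model_parallel_size = \
    (num_layers // pipeline_model_parallel_size) // \
    num_layers_per_virtual_pipeline_stage
print(virtual_pipeline_model_parallel_size)  # 2 virtual chunks per pipeline stage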
# Parameters dtype. # Parameters dtype.
args.params_dtype = torch.float args.params_dtype = torch.float
if args.fp16: if args.fp16:
assert not args.bf16
args.params_dtype = torch.half args.params_dtype = torch.half
if args.bf16:
assert not args.fp16
args.params_dtype = torch.bfloat16
if args.rank == 0: if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype), print('using {} for parameters ...'.format(args.params_dtype),
flush=True) flush=True)
# If we do accumulation and all-reduces in fp32, we need to use # If we do accumulation and all-reduces in fp32, we need to use
# local DDP and enable use-contiguous-buffers-in-ddp. # local DDP and enable use-contiguous-buffers-in-ddp.
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
args.use_contiguous_buffers_in_ddp = True
if args.dataloader_type is None: if args.dataloader_type is None:
args.dataloader_type = 'single' args.dataloader_type = 'single'
...@@ -195,39 +217,14 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -195,39 +217,14 @@ def parse_args(extra_args_provider=None, defaults={},
if args.fp16_lm_cross_entropy: if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
if args.fp32_residual_connection: if args.fp32_residual_connection:
assert args.fp16, \ assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16.' 'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing. # Activation checkpointing.
if args.distribute_checkpointed_activations: if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \ assert args.checkpoint_activations, \
'for distribute-checkpointed-activations to work you '\ 'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations' 'need to enable checkpoint-activations'
# custom kernel constraints check
seq_len = args.seq_length
attn_batch_size = \
(args.num_attention_heads / args.tensor_model_parallel_size) * \
args.micro_batch_size
# constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = seq_len > 16 and seq_len <=2048 and \
seq_len % 4 == 0 and attn_batch_size % 4 == 0
if args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion:
print('WARNING: constraints for invoking optimized'
' fused softmax kernel are not met. We default back to unfused'
' kernel invocations.')
# Load scaled_masked_softmax_fusion_kernels
if args.masked_softmax_fusion:
fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
fused_kernels.load_scaled_masked_softmax_fusion_kernel()
# Load mixed precision fused layer norm.
if args.fp32_residual_connection:
fused_kernels.load_fused_mix_prec_layer_norm_kernel()
_print_args(args) _print_args(args)
return args return args
...@@ -299,6 +296,8 @@ def _add_logging_args(parser): ...@@ -299,6 +296,8 @@ def _add_logging_args(parser):
group.add_argument('--log-params-norm', action='store_true', group.add_argument('--log-params-norm', action='store_true',
help='If set, calculate and log parameters norm.') help='If set, calculate and log parameters norm.')
group.add_argument('--log-num-zeros-in-grad', action='store_true',
help='If set, calculate and log the number of zeros in gradient.')
group.add_argument('--tensorboard-log-interval', type=int, default=1, group.add_argument('--tensorboard-log-interval', type=int, default=1,
help='Report to tensorboard interval.') help='Report to tensorboard interval.')
group.add_argument('--tensorboard-queue-size', type=int, default=1000, group.add_argument('--tensorboard-queue-size', type=int, default=1000,
...@@ -517,6 +516,8 @@ def _add_mixed_precision_args(parser): ...@@ -517,6 +516,8 @@ def _add_mixed_precision_args(parser):
group.add_argument('--fp16', action='store_true', group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.') help='Run model in fp16 mode.')
group.add_argument('--bf16', action='store_true',
help='Run model in bfloat16 mode.')
group.add_argument('--loss-scale', type=float, default=None, group.add_argument('--loss-scale', type=float, default=None,
help='Static loss scaling, positive power of 2 ' help='Static loss scaling, positive power of 2 '
'values can improve fp16 convergence. If None, dynamic' 'values can improve fp16 convergence. If None, dynamic'
...@@ -538,8 +539,9 @@ def _add_mixed_precision_args(parser): ...@@ -538,8 +539,9 @@ def _add_mixed_precision_args(parser):
help='Run attention masking and softmax in fp32. ' help='Run attention masking and softmax in fp32. '
'This flag is ignored unless ' 'This flag is ignored unless '
'--no-query-key-layer-scaling is specified.') '--no-query-key-layer-scaling is specified.')
group.add_argument('--fp32-allreduce', action='store_true', group.add_argument('--accumulate-allreduce-grads-in-fp32',
help='All-reduce in fp32') action='store_true',
help='Gradient accumulation and all-reduce in fp32.')
group.add_argument('--fp16-lm-cross-entropy', action='store_true', group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation' help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.') 'for lm head to fp16.')
...@@ -557,6 +559,8 @@ def _add_distributed_args(parser): ...@@ -557,6 +559,8 @@ def _add_distributed_args(parser):
group.add_argument('--model-parallel-size', type=int, default=None, group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use ' help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.') '--tensor-model-parallel-size instead.')
group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
help='Number of layers per virtual pipeline stage')
group.add_argument('--distributed-backend', default='nccl', group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'], choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.') help='Which backend to use for distributed training.')
...@@ -564,6 +568,12 @@ def _add_distributed_args(parser): ...@@ -564,6 +568,12 @@ def _add_distributed_args(parser):
choices=['local', 'torch'], choices=['local', 'torch'],
help='which DistributedDataParallel implementation ' help='which DistributedDataParallel implementation '
'to use.') 'to use.')
group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
help='If set, use contiguous buffer in DDP. Note that '
'this option only works with local DDP.')
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
help='Use scatter/gather to optimize communication of tensors in pipeline',
dest='scatter_gather_tensors_in_pipeline')
group.add_argument('--local_rank', type=int, default=None, group.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher.') help='local rank passed from distributed launcher.')
group.add_argument('--lazy-mpu-init', type=bool, required=False, group.add_argument('--lazy-mpu-init', type=bool, required=False,
...@@ -615,6 +625,12 @@ def _add_data_args(parser): ...@@ -615,6 +625,12 @@ def _add_data_args(parser):
'This should be exclusive of --seq-length') 'This should be exclusive of --seq-length')
group.add_argument('--decoder-seq-length', type=int, default=None, group.add_argument('--decoder-seq-length', type=int, default=None,
help="Maximum decoder sequence length to process.") help="Maximum decoder sequence length to process.")
group.add_argument('--retriever-seq-length', type=int, default=256,
help='Maximum sequence length for the biencoder model '
' for retriever')
group.add_argument('--sample-rate', type=float, default=1.0,
help='sample rate for training data. Supposed to be 0 '
' < sample_rate < 1')
group.add_argument('--mask-prob', type=float, default=0.15, group.add_argument('--mask-prob', type=float, default=0.15,
help='Probability of replacing a token with mask.') help='Probability of replacing a token with mask.')
group.add_argument('--short-seq-prob', type=float, default=0.1, group.add_argument('--short-seq-prob', type=float, default=0.1,
...@@ -655,13 +671,19 @@ def _add_autoresume_args(parser): ...@@ -655,13 +671,19 @@ def _add_autoresume_args(parser):
return parser return parser
def _add_realm_args(parser): def _add_biencoder_args(parser):
group = parser.add_argument_group(title='realm') group = parser.add_argument_group(title='biencoder')
# network size # network size
group.add_argument('--ict-head-size', type=int, default=None, group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and ' help='Size of block embeddings to be used in ICT and '
'REALM (paper default: 128)') 'REALM (paper default: 128)')
group.add_argument('--biencoder-projection-dim', type=int, default=0,
help='Size of projection head used in biencoder (paper'
' default: 128)')
group.add_argument('--biencoder-shared-query-context-model', action='store_true',
help='Whether to share the parameters of the query '
'and context models or not')
# checkpointing # checkpointing
group.add_argument('--ict-load', type=str, default=None, group.add_argument('--ict-load', type=str, default=None,
...@@ -678,16 +700,23 @@ def _add_realm_args(parser): ...@@ -678,16 +700,23 @@ def _add_realm_args(parser):
'ICT dataset') 'ICT dataset')
group.add_argument('--use-one-sent-docs', action='store_true', group.add_argument('--use-one-sent-docs', action='store_true',
help='Whether to use one sentence documents in ICT') help='Whether to use one sentence documents in ICT')
group.add_argument('--evidence-data-path', type=str, default=None,
help='Path to Wikipedia Evidence from DPR paper')
# training # training
group.add_argument('--report-topk-accuracies', nargs='+', default=[], group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
help="Which top-k accuracies to report (e.g. '1 5 20')") default=[], help="Which top-k accuracies to report "
"(e.g. '1 5 20')")
group.add_argument('--retriever-score-scaling', action='store_true',
help='Whether to scale retriever scores by inverse '
'square root of hidden size')
# faiss index # faiss index
group.add_argument('--faiss-use-gpu', action='store_true',
help='Whether create the FaissMIPSIndex on GPU')
group.add_argument('--block-data-path', type=str, default=None, group.add_argument('--block-data-path', type=str, default=None,
help='Where to save/load BlockData to/from') help='Where to save/load BlockData to/from')
group.add_argument('--embedding-path', type=str, default=None,
help='Where to save/load Open-Retrieval Embedding'
' data to/from')
# indexer # indexer
group.add_argument('--indexer-batch-size', type=int, default=128, group.add_argument('--indexer-batch-size', type=int, default=128,
......
...@@ -21,12 +21,12 @@ import sys ...@@ -21,12 +21,12 @@ import sys
import numpy as np import numpy as np
import torch import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from megatron import (get_args, from megatron import (get_args,
mpu, mpu,
print_rank_0, print_rank_0,
update_num_microbatches) update_num_microbatches,
utils)
_CHECKPOINT_VERSION = None _CHECKPOINT_VERSION = None
...@@ -111,8 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -111,8 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
args = get_args() args = get_args()
# Only rank zero of the data parallel writes to the disk. # Only rank zero of the data parallel writes to the disk.
if isinstance(model, torchDDP): model = utils.unwrap_model(model)
model = model.module
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format( print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save)) iteration, args.save))
...@@ -124,7 +123,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -124,7 +123,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
state_dict['args'] = args state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0 state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration state_dict['iteration'] = iteration
state_dict['model'] = model.state_dict_for_save_checkpoint() if len(model) == 1:
state_dict['model'] = model[0].state_dict_for_save_checkpoint()
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
# Optimizer stuff. # Optimizer stuff.
if not args.no_save_optim: if not args.no_save_optim:
...@@ -202,6 +206,33 @@ def _transpose_first_dim(t, num_splits, num_splits_first, model): ...@@ -202,6 +206,33 @@ def _transpose_first_dim(t, num_splits, num_splits_first, model):
return t return t
def fix_query_key_value_ordering(model, checkpoint_version):
"""Fix up query/key/value matrix ordering if checkpoint
version is smaller than 2.0
"""
if checkpoint_version < 2.0:
for name, param in model.named_parameters():
if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 3, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 3, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
if name.endswith(('.key_value.weight', '.key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 2, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 2, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version))
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True): def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration. """Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in strict (bool): whether to strictly enforce that the keys in
...@@ -211,8 +242,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True ...@@ -211,8 +242,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
args = get_args() args = get_args()
load_dir = getattr(args, load_arg) load_dir = getattr(args, load_arg)
if isinstance(model, torchDDP): model = utils.unwrap_model(model)
model = model.module
# Read the tracker file and set the iteration. # Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(load_dir) tracker_filename = get_checkpoint_tracker_filename(load_dir)
...@@ -297,30 +328,17 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True ...@@ -297,30 +328,17 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
print_rank_0('could not find arguments in the checkpoint ...') print_rank_0('could not find arguments in the checkpoint ...')
# Model. # Model.
model.load_state_dict(state_dict['model'], strict=strict) if len(model) == 1:
model[0].load_state_dict(state_dict['model'], strict=strict)
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering # Fix up query/key/value matrix ordering if needed
if get_checkpoint_version() < 2.0: checkpoint_version = get_checkpoint_version()
checkpoint_version = get_checkpoint_version() print_rank_0(f' checkpoint version {checkpoint_version}')
for name, param in model.named_parameters(): fix_query_key_value_ordering(model, checkpoint_version)
if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 3, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 3, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
if name.endswith(('.key_value.weight', '.key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 2, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 2, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
# Optimizer. # Optimizer.
if not release and not args.finetune and not args.no_load_optim: if not release and not args.finetune and not args.no_load_optim:
...@@ -365,41 +383,42 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True ...@@ -365,41 +383,42 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
return iteration return iteration
def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False): def load_biencoder_checkpoint(model, only_query_model=False,
"""selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints""" only_context_model=False, custom_load_path=None):
"""
selectively load retrieval models for indexing/retrieving
from saved checkpoints
"""
args = get_args() args = get_args()
if isinstance(model, torchDDP): model = utils.unwrap_model(model)
model = model.module
load_path = args.load if from_realm_chkpt else args.ict_load load_path = custom_load_path if custom_load_path is not None else args.load
tracker_filename = get_checkpoint_tracker_filename(load_path) tracker_filename = get_checkpoint_tracker_filename(load_path)
with open(tracker_filename, 'r') as f: with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip()) iteration = int(f.read().strip())
# assert iteration > 0
checkpoint_name = get_checkpoint_name(load_path, iteration, False) checkpoint_name = get_checkpoint_name(load_path, iteration, False)
if mpu.get_data_parallel_rank() == 0: if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format( print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name)) torch.distributed.get_rank(), checkpoint_name))
state_dict = torch.load(checkpoint_name, map_location='cpu') state_dict = torch.load(checkpoint_name, map_location='cpu')
ict_state_dict = state_dict['model'] ret_state_dict = state_dict['model']
if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
print(" loading ICT state dict from REALM", flush=True)
ict_state_dict = ict_state_dict['retriever']['ict_model']
if only_query_model: if only_query_model:
ict_state_dict.pop('context_model') ret_state_dict.pop('context_model')
if only_block_model: if only_context_model:
ict_state_dict.pop('question_model') ret_state_dict.pop('query_model')
model.load_state_dict(ict_state_dict) assert len(model) == 1
model[0].load_state_dict(ret_state_dict)
torch.distributed.barrier() torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0: if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name)) print(' successfully loaded {}'.format(checkpoint_name))
return model return model
import os
import time
import numpy as np
import torch
from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, \
pad_and_convert_to_numpy
from megatron.data.data_samplers import MegatronPretrainingSampler
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
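# Illustrative check of make_attention_mask (values are mine, not from the
# source): padded positions (id 0) are masked out along both axes.
_source = np.array([101, 7, 8, 102, 0, 0])   # 4 real tokens, 2 pads
_target = np.array([101, 9, 102, 0])         # 3 real tokens, 1 pad
_mask = make_attention_mask(_source, _target)
assert _mask.shape == (6, 4)                 # (source_length, target_length)
assert _mask[0].tolist() == [1, 1, 1, 0]     # real source token sees the real targets
assert _mask[4].tolist() == [0, 0, 0, 0]     # padded source position sees nothing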
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
"""Specifically one epoch to be used in an indexing job."""
args = get_args()
if micro_batch_size is None:
micro_batch_size = args.micro_batch_size
num_workers = args.num_workers
# Use Megatron's sampler with consumed_samples set to 0, since this is
# only used for evaluation and we don't intend to resume halfway.
# Also set drop_last to False, since we don't want to drop the last
# (partial) batch.
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=0,
micro_batch_size=micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size(),
drop_last=False)
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
def get_ict_batch(data_iterator):
# Items and their type.
keys = ['query_tokens', 'query_mask',
'context_tokens', 'context_mask', 'block_data']
datatype = torch.int64
# Broadcast data.
if data_iterator is None:
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
query_mask = data_b['query_mask'] < 0.5
context_tokens = data_b['context_tokens'].long()
context_mask = data_b['context_mask'] < 0.5
block_indices = data_b['block_data'].long()
return query_tokens, query_mask,\
context_tokens, context_mask, block_indices
def join_str_list(str_list):
"""Join a list of strings, handling spaces appropriately"""
result = ""
for s in str_list:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
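# Illustrative usage of join_str_list (tokens are mine, not from the source):
# WordPiece continuation pieces ("##...") are glued to the previous piece;
# note that the result keeps a single leading space.
assert join_str_list(["the", "quick", "brow", "##n", "fox"]) == " the quick brown fox"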
class BlockSampleData(object):
"""A struct for fully describing a fixed-size block of data as used in REALM
:param start_idx: for first sentence of the block
:param end_idx: for last sentence of the block (may be partially truncated in sample construction)
:param doc_idx: the index of the document from which the block comes in the original indexed dataset
:param block_idx: a unique integer identifier given to every block.
"""
def __init__(self, start_idx, end_idx, doc_idx, block_idx):
self.start_idx = start_idx
self.end_idx = end_idx
self.doc_idx = doc_idx
self.block_idx = block_idx
def as_array(self):
return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
def as_tuple(self):
return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
class BlockSamplesMapping(object):
def __init__(self, mapping_array):
# make sure that the array is compatible with BlockSampleData
assert mapping_array.shape[1] == 4
self.mapping_array = mapping_array
def __len__(self):
return self.mapping_array.shape[0]
def __getitem__(self, idx):
"""Get the data associated with an indexed sample."""
sample_data = BlockSampleData(*self.mapping_array[idx])
return sample_data
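# Illustrative usage (array values are mine, not from the source): each row of
# the mapping array is (start_idx, end_idx, doc_idx, block_idx).
_mapping = np.array([[0, 3, 7, 0],
                     [3, 5, 7, 1]], dtype=np.int64)
_samples_mapping = BlockSamplesMapping(_mapping)
assert len(_samples_mapping) == 2
assert _samples_mapping[1].as_tuple() == (3, 5, 7, 1)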
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
"""Get samples mapping for a dataset over fixed size blocks. This function also requires
a dataset of the titles for the source documents since their lengths must be taken into account.
:return: samples_mapping (BlockSamplesMapping)
"""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{}s'.format(seed)
if use_one_sent_docs:
indexmap_filename += '_1sentok'
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if mpu.get_data_parallel_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert block_dataset.doc_idx.dtype == np.int64
assert block_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
from megatron.data import helpers
mapping_array = helpers.build_blocks_mapping(
block_dataset.doc_idx,
block_dataset.sizes,
title_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length - 3, # account for added tokens
seed,
verbose,
use_one_sent_docs)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, mapping_array, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
samples_mapping = BlockSamplesMapping(mapping_array)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
mapping_array.shape[0]))
return samples_mapping
...@@ -57,7 +57,7 @@ def build_pretraining_data_loader(dataset, consumed_samples): ...@@ -57,7 +57,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
class MegatronPretrainingSampler: class MegatronPretrainingSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size, def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size): data_parallel_rank, data_parallel_size, drop_last=True):
# Keep a copy of input params for later use. # Keep a copy of input params for later use.
self.total_samples = total_samples self.total_samples = total_samples
self.consumed_samples = consumed_samples self.consumed_samples = consumed_samples
...@@ -65,6 +65,7 @@ class MegatronPretrainingSampler: ...@@ -65,6 +65,7 @@ class MegatronPretrainingSampler:
self.data_parallel_rank = data_parallel_rank self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = \ self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
# Sanity checks. # Sanity checks.
assert self.total_samples > 0, \ assert self.total_samples > 0, \
...@@ -81,17 +82,26 @@ class MegatronPretrainingSampler: ...@@ -81,17 +82,26 @@ class MegatronPretrainingSampler:
def __len__(self): def __len__(self):
return self.total_samples return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx
def __iter__(self): def __iter__(self):
batch = [] batch = []
# Last batch if not complete will be dropped. # Last batch will be dropped if drop_last is not set False
for idx in range(self.consumed_samples, self.total_samples): for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx) batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size: if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx = self.data_parallel_rank * self.micro_batch_size start_idx, end_idx = self.get_start_end_idx()
end_idx = start_idx + self.micro_batch_size
yield batch[start_idx:end_idx] yield batch[start_idx:end_idx]
batch = [] batch = []
# Yield the last partial batch if drop_last is not set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
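A self-contained sketch (not the Megatron class itself) of what the new drop_last flag changes: with drop_last=False the trailing partial batch is also yielded.
def _iter_batches(total_samples, micro_batch_size, drop_last=True):
    batch = []
    for idx in range(total_samples):
        batch.append(idx)
        if len(batch) == micro_batch_size:
            yield batch
            batch = []
    # the new behaviour: emit the last partial batch unless drop_last is set
    if len(batch) > 0 and not drop_last:
        yield batch

assert list(_iter_batches(5, 2, drop_last=True)) == [[0, 1], [2, 3]]
assert list(_iter_batches(5, 2, drop_last=False)) == [[0, 1], [2, 3], [4]]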
class MegatronPretrainingRandomSampler: class MegatronPretrainingRandomSampler:
......
...@@ -9,6 +9,16 @@ from megatron import get_args ...@@ -9,6 +9,16 @@ from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_ from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping from megatron.data.realm_dataset_utils import get_block_samples_mapping
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def get_ict_dataset(use_titles=True, query_in_block_prob=1): def get_ict_dataset(use_titles=True, query_in_block_prob=1):
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
...@@ -39,7 +49,7 @@ class ICTDataset(Dataset): ...@@ -39,7 +49,7 @@ class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task.""" """Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix, def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length, query_in_block_prob, num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
seed, use_titles=True, use_one_sent_docs=False): seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
self.name = name self.name = name
self.seed = seed self.seed = seed
self.max_seq_length = max_seq_length self.max_seq_length = max_seq_length
...@@ -93,14 +103,20 @@ class ICTDataset(Dataset): ...@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset] block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query) query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title) context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
query_mask = make_attention_mask(query_tokens, query_tokens)
context_mask = make_attention_mask(context_tokens, context_tokens)
block_data = sample_data.as_array() block_data = sample_data.as_array()
sample = { sample = {
'query_tokens': query_tokens, 'query_tokens': query_tokens,
'query_mask': query_mask,
'query_pad_mask': query_pad_mask, 'query_pad_mask': query_pad_mask,
'block_tokens': block_tokens, 'context_tokens': context_tokens,
'block_pad_mask': block_pad_mask, 'context_mask': context_mask,
'context_pad_mask': context_pad_mask,
'block_data': block_data, 'block_data': block_data,
} }
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wikipedia dataset from DPR code for ORQA."""
from abc import ABC
import csv
import numpy as np
import random
import torch
from torch.utils.data import Dataset
from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_open_retrieval_wiki_dataset():
args = get_args()
tokenizer = get_tokenizer()
dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
'evidence',
args.evidence_data_path,
tokenizer,
args.retriever_seq_length)
return dataset
def get_open_retrieval_batch(data_iterator):
# Items and their type.
keys = ['row_id', 'context', 'context_mask', 'context_types',
'context_pad_mask']
datatype = torch.int64
# Broadcast data.
data = None if data_iterator is None else next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
row_id = data_b['row_id'].long()
context = data_b['context'].long()
# TODO: make the context mask a binary one
context_mask = (data_b['context_mask'] < 0.5)
context_types = data_b['context_types'].long()
context_pad_mask = data_b['context_pad_mask'].long()
return row_id, context, context_mask, context_types, context_pad_mask
def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
"""Build token types and paddings, trim if needed, and pad if needed."""
title_ids = tokenizer.tokenize(row['title'])
context_ids = tokenizer.tokenize(row['text'])
# Prepend the title in front of the context
extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids
context_ids, context_types, context_pad_mask = \
build_tokens_types_paddings_from_ids(extended_context_ids,
max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
return context_ids, context_types, context_pad_mask
# noinspection DuplicatedCode
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
cls_id, sep_id, pad_id):
"""Build token types and paddings, trim if needed, and pad if needed."""
enc_ids = []
tokentypes_enc = []
# [CLS].
enc_ids.append(cls_id)
tokentypes_enc.append(0)
# A.
len_src = len(text_ids)
enc_ids.extend(text_ids)
tokentypes_enc.extend([0] * len_src)
# Cap the size.
if len(enc_ids) > max_seq_length - 1:
enc_ids = enc_ids[0: max_seq_length - 1]
tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
# [SEP].
enc_ids.append(sep_id)
tokentypes_enc.append(0)
num_tokens_enc = len(enc_ids)
# Padding.
padding_length = max_seq_length - len(enc_ids)
if padding_length > 0:
enc_ids.extend([pad_id] * padding_length)
tokentypes_enc.extend([pad_id] * padding_length)
pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
pad_mask = np.array(pad_mask, dtype=np.int64)
return enc_ids, tokentypes_enc, pad_mask
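# Illustrative call (ids are mine, not from the source) showing the
# [CLS] text [SEP] + padding layout produced above, for max_seq_length = 8,
# cls_id = 101, sep_id = 102, pad_id = 0:
_ids, _types, _pad_mask = build_tokens_types_paddings_from_ids(
    [5, 6, 7], 8, cls_id=101, sep_id=102, pad_id=0)
assert _ids == [101, 5, 6, 7, 102, 0, 0, 0]
assert _types == [0, 0, 0, 0, 0, 0, 0, 0]
assert _pad_mask.tolist() == [1, 1, 1, 1, 1, 0, 0, 0]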
def build_sample(row_id, context_ids, context_types, context_pad_mask):
"""Convert to numpy and return a sample consumed by the batch producer."""
context_ids = np.array(context_ids, dtype=np.int64)
context_types = np.array(context_types, dtype=np.int64)
context_mask = make_attention_mask(context_ids, context_ids)
sample = ({
'row_id': row_id,
'context': context_ids,
'context_mask': context_mask,
'context_types': context_types,
'context_pad_mask': context_pad_mask
})
return sample
class OpenRetrievalEvidenceDataset(ABC, Dataset):
"""Open Retrieval Evidence dataset class."""
def __init__(self, task_name, dataset_name, datapath, tokenizer,
max_seq_length):
# Store inputs.
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
# Process the files.
print_rank_0(datapath)
self.samples, self.id2text = self.process_samples_from_single_path(
datapath)
args = get_args()
if args.sample_rate < 1: # subsample
k = int(len(self.samples) * args.sample_rate)
self.samples = random.sample(self.samples, k)
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
row = self.samples[idx]
context_ids, context_types, context_pad_mask = \
build_tokens_types_paddings_from_text(row, self.tokenizer,
self.max_seq_length)
sample = build_sample(row['doc_id'],
context_ids,
context_types,
context_pad_mask)
return sample
@staticmethod
def process_samples_from_single_path(filename):
print_rank_0(' > Processing {} ...'.format(filename))
total = 0
rows = []
id2text = {}
with open(filename) as tsvfile:
reader = csv.reader(tsvfile, delimiter='\t')
next(reader, None) # skip the headers
for row in reader:
# file format: doc_id, doc_text, title
doc_id = int(row[0])
text = row[1]
title = row[2]
rows.append({'doc_id': doc_id,
'text': text,
'title': title})
assert doc_id not in id2text
id2text[doc_id] = (text, title)
total += 1
if total % 100000 == 0:
print_rank_0(' > processed {} rows so far ...'.format(
total))
print_rank_0(' >> processed {} samples.'.format(len(rows)))
return rows, id2text
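# Hypothetical example (contents are mine, not from DPR) of the file layout
# process_samples_from_single_path expects: tab-separated, one header row
# (skipped), then doc_id, doc_text, title on each line.
_example_tsv = (
    "id\ttext\ttitle\n"
    "1\tAaron is a prophet in the Abrahamic religions.\tAaron\n"
    "2\tThe Eiffel Tower is a wrought-iron lattice tower in Paris.\tEiffel Tower\n"
)
with open("evidence_sample.tsv", "w") as _f:
    _f.write(_example_tsv)
# process_samples_from_single_path("evidence_sample.tsv") then returns two row
# dicts and an id2text dict keyed by the integer doc_ids 1 and 2.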
...@@ -14,34 +14,36 @@ def detach(tensor): ...@@ -14,34 +14,36 @@ def detach(tensor):
return tensor.detach().cpu().numpy() return tensor.detach().cpu().numpy()
class BlockData(object): class OpenRetreivalDataStore(object):
"""Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM""" """
def __init__(self, block_data_path=None, load_from_path=True, rank=None): Serializable data structure for holding data for blocks --
embeddings and necessary metadata for Retriever
"""
def __init__(self, embedding_path=None, load_from_path=True, rank=None):
self.embed_data = dict() self.embed_data = dict()
self.meta_data = dict() if embedding_path is None:
if block_data_path is None:
args = get_args() args = get_args()
block_data_path = args.block_data_path embedding_path = args.embedding_path
rank = args.rank rank = args.rank
self.block_data_path = block_data_path self.embedding_path = embedding_path
self.rank = rank self.rank = rank
if load_from_path: if load_from_path:
self.load_from_file() self.load_from_file()
block_data_name = os.path.splitext(self.block_data_path)[0] block_data_name = os.path.splitext(self.embedding_path)[0]
self.temp_dir_name = block_data_name + '_tmp' self.temp_dir_name = block_data_name + '_tmp'
def state(self): def state(self):
return { return {
'embed_data': self.embed_data, 'embed_data': self.embed_data,
'meta_data': self.meta_data,
} }
def clear(self): def clear(self):
"""Clear the embedding data structures to save memory. """
The metadata ends up getting used, and is also much smaller in dimensionality Clear the embedding data structures to save memory.
so it isn't really worth clearing. The metadata ends up getting used, and is also much smaller in
dimensionality so it isn't really worth clearing.
""" """
self.embed_data = dict() self.embed_data = dict()
...@@ -50,38 +52,39 @@ class BlockData(object): ...@@ -50,38 +52,39 @@ class BlockData(object):
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Unpickling BlockData", flush=True) print("\n> Unpickling BlockData", flush=True)
state_dict = pickle.load(open(self.block_data_path, 'rb')) state_dict = pickle.load(open(self.embedding_path, 'rb'))
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Finished unpickling BlockData\n", flush=True) print(">> Finished unpickling BlockData\n", flush=True)
self.embed_data = state_dict['embed_data'] self.embed_data = state_dict['embed_data']
self.meta_data = state_dict['meta_data']
def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False): def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
"""Add data for set of blocks """
:param block_indices: 1D array of unique int ids for the blocks Add data for set of blocks
:param row_id: 1D array of unique int ids for the blocks
:param block_embeds: 2D array of embeddings of the blocks :param block_embeds: 2D array of embeddings of the blocks
:param block_metas: 2D array of metadata for the blocks. In the case of retriever this will be [start_idx, end_idx, doc_idx]
In the case of REALM this will be [start_idx, end_idx, doc_idx]
""" """
for idx, embed, meta in zip(block_indices, block_embeds, block_metas): for idx, embed in zip(row_id, block_embeds):
if not allow_overwrite and idx in self.embed_data: if not allow_overwrite and idx in self.embed_data:
raise ValueError("Unexpectedly tried to overwrite block data") raise ValueError("Unexpectedly tried to overwrite block data")
self.embed_data[idx] = np.float16(embed) self.embed_data[idx] = np.float16(embed)
self.meta_data[idx] = meta
def save_shard(self): def save_shard(self):
"""Save the block data that was created this in this process""" """
Save the block data that was created this in this process
"""
if not os.path.isdir(self.temp_dir_name): if not os.path.isdir(self.temp_dir_name):
os.makedirs(self.temp_dir_name, exist_ok=True) os.makedirs(self.temp_dir_name, exist_ok=True)
# save the data for each shard # save the data for each shard
with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file: with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
pickle.dump(self.state(), data_file) as writer:
pickle.dump(self.state(), writer)
def merge_shards_and_save(self): def merge_shards_and_save(self):
"""Combine all the shards made using self.save_shard()""" #Combine all the shards made using save_shard
shard_names = os.listdir(self.temp_dir_name) shard_names = os.listdir(self.temp_dir_name)
seen_own_shard = False seen_own_shard = False
...@@ -96,15 +99,15 @@ class BlockData(object): ...@@ -96,15 +99,15 @@ class BlockData(object):
old_size = len(self.embed_data) old_size = len(self.embed_data)
shard_size = len(data['embed_data']) shard_size = len(data['embed_data'])
# add the shard's data and check to make sure there is no overlap # add the shard's data and check to make sure there
# is no overlap
self.embed_data.update(data['embed_data']) self.embed_data.update(data['embed_data'])
self.meta_data.update(data['meta_data'])
assert len(self.embed_data) == old_size + shard_size assert len(self.embed_data) == old_size + shard_size
assert seen_own_shard assert seen_own_shard
# save the consolidated shards and remove temporary directory # save the consolidated shards and remove temporary directory
with open(self.block_data_path, 'wb') as final_file: with open(self.embedding_path, 'wb') as final_file:
pickle.dump(self.state(), final_file) pickle.dump(self.state(), final_file)
shutil.rmtree(self.temp_dir_name, ignore_errors=True) shutil.rmtree(self.temp_dir_name, ignore_errors=True)
...@@ -113,18 +116,22 @@ class BlockData(object): ...@@ -113,18 +116,22 @@ class BlockData(object):
class FaissMIPSIndex(object): class FaissMIPSIndex(object):
"""Wrapper object for a BlockData which similarity search via FAISS under the hood""" """
def __init__(self, embed_size, block_data=None, use_gpu=False): Wrapper object for a BlockData which similarity search via FAISS under the hood
"""
def __init__(self, embed_size, embed_data=None, use_gpu=False):
self.embed_size = embed_size self.embed_size = embed_size
self.block_data = block_data self.embed_data = embed_data
self.use_gpu = use_gpu self.use_gpu = use_gpu
self.id_map = dict()
self.block_mips_index = None self.mips_index = None
self._set_block_index() self._set_mips_index()
def _set_block_index(self): def _set_mips_index(self):
"""Create a Faiss Flat index with inner product as the metric to search against""" """
Create a Faiss Flat index with inner product as the metric
to search against
"""
try: try:
import faiss import faiss
except ImportError: except ImportError:
...@@ -132,85 +139,86 @@ class FaissMIPSIndex(object): ...@@ -132,85 +139,86 @@ class FaissMIPSIndex(object):
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Building index", flush=True) print("\n> Building index", flush=True)
self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
cpu_index = faiss.IndexFlatIP(self.embed_size)
if self.use_gpu: if self.use_gpu:
# create resources and config for GpuIndex # create resources and config for GpuIndex
res = faiss.StandardGpuResources() config = faiss.GpuMultipleClonerOptions()
config = faiss.GpuIndexFlatConfig() config.shard = True
config.device = torch.cuda.current_device()
config.useFloat16 = True config.useFloat16 = True
gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config) self.mips_index = faiss.IndexIDMap(gpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True) print(">> Initialized index on GPU", flush=True)
else: else:
# CPU index supports IDs so wrap with IDMap # CPU index supports IDs so wrap with IDMap
self.block_mips_index = faiss.IndexIDMap(self.block_mips_index) self.mips_index = faiss.IndexIDMap(cpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on CPU", flush=True) print(">> Initialized index on CPU", flush=True)
# if we were constructed with a BlockData, then automatically load it when the FAISS structure is built # if we were constructed with a BlockData, then automatically load it
if self.block_data is not None: # when the FAISS structure is built
self.add_block_embed_data(self.block_data) if self.embed_data is not None:
self.add_embed_data(self.embed_data)
def reset_index(self): def reset_index(self):
"""Delete existing index and create anew""" """Delete existing index and create a new"""
del self.block_mips_index del self.mips_index
# reset the block data so that _set_block_index will reload it as well # reset the block data so that _set_block_index will reload it as well
if self.block_data is not None: if self.embed_data is not None:
block_data_path = self.block_data.block_data_path embed_data_path = self.embed_data.embedding_path
del self.block_data del self.embed_data
self.block_data = BlockData(block_data_path) self.embed_data = OpenRetreivalDataStore(embed_data_path)
self._set_block_index() self._set_mips_index()
def add_block_embed_data(self, all_block_data): def update_index(self):
"""Delete existing index and create a new"""
del self.mips_index
# reset the block data so that _set_mips_index will reload it as well
if self.embed_data is not None:
self.embed_data.load_from_file()
self._set_mips_index()
def add_embed_data(self, all_embed_data):
"""Add the embedding of each block to the underlying FAISS index""" """Add the embedding of each block to the underlying FAISS index"""
# this assumes the embed_data is a dict : {int: np.array<float>} # this assumes the embed_data is a dict : {int: np.array<float>}
block_indices, block_embeds = zip(*all_block_data.embed_data.items()) block_indices, block_embeds = zip(*all_embed_data.embed_data.items())
# the embeddings have to be entered in as float32 even though the math internally is done with float16. # the embeddings have to be entered in as float32 even though the math
block_embeds_arr = np.float32(np.array(block_embeds)) # internally is done with float16.
block_indices_arr = np.array(block_indices) embeds_arr = np.float32(np.array(block_embeds))
indices_arr = np.array(block_indices)
# faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
if self.use_gpu:
for i, idx in enumerate(block_indices):
self.id_map[i] = idx
# we no longer need the embedding data since it's in the index now # we no longer need the embedding data since it's in the index now
all_block_data.clear() all_embed_data.clear()
if self.use_gpu: self.mips_index.add_with_ids(embeds_arr, indices_arr)
self.block_mips_index.add(block_embeds_arr)
else:
self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">>> Finished adding block data to index", flush=True) print(">>> Finished adding block data to index", flush=True)
def search_mips_index(self, query_embeds, top_k, reconstruct=True): def search_mips_index(self, query_embeds, top_k, reconstruct=True):
"""Get the top-k blocks by the index distance metric. """
Get the top-k blocks by the index distance metric.
:param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks :param reconstruct: if True: return a [num_queries x k x embed_dim]
if False: return [num_queries x k] array of distances, and another for indices array of blocks
if False: return [num_queries x k] array of
distances, and another for indices
""" """
query_embeds = np.float32(detach(query_embeds)) query_embeds = np.float32(detach(query_embeds))
if reconstruct: if reconstruct:
# get the vectors themselves # get the vectors themselves
top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k) top_k_block_embeds = self.mips_index.search_and_reconstruct(\
query_embeds, top_k)
return top_k_block_embeds return top_k_block_embeds
else: else:
# get distances and indices of closest vectors # get distances and indices of closest vectors
distances, block_indices = self.block_mips_index.search(query_embeds, top_k) distances, block_indices = self.mips_index.search(query_embeds, top_k)
if self.use_gpu:
fresh_indices = np.zeros(block_indices.shape)
for i, j in itertools.product(block_indices.shape):
fresh_indices[i, j] = self.id_map[block_indices[i, j]]
block_indices = fresh_indices
return distances, block_indices return distances, block_indices
...@@ -13,114 +13,97 @@ ...@@ -13,114 +13,97 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import pathlib import pathlib
import subprocess import subprocess
import os
from torch.utils import cpp_extension from torch.utils import cpp_extension
# Setting this param to a list has a problem of generating # Setting this param to a list has a problem of generating different
# different compilation commands (with different order of architectures) # compilation commands (with different order of architectures) and
# and leading to recompilation of fused kernels. # leading to recompilation of fused kernels. Set it to empty string
# set it to empty string to avoid recompilation # to avoid recompilation and assign arch flags explicitly in
# and assign arch flags explicitly in extra_cuda_cflags below # extra_cuda_cflags below
os.environ["TORCH_CUDA_ARCH_LIST"] = "" os.environ["TORCH_CUDA_ARCH_LIST"] = ""
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")
def load_scaled_upper_triang_masked_softmax_fusion_kernel(): def load(args):
# Check, if CUDA11 is installed for compute capability 8.0 # Check if cuda 11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) _, bare_metal_major, _ = _get_cuda_bare_metal_version(
cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11: if int(bare_metal_major) >= 11:
cc_flag.append('-gencode') cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80') cc_flag.append('arch=compute_80,code=sm_80')
# Build path
srcpath = pathlib.Path(__file__).parent.absolute() srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build' buildpath = srcpath / 'build'
_create_build_dir(buildpath)
create_build_dir(buildpath)
# Helper function to build the kernels.
scaled_upper_triang_masked_softmax_cuda = cpp_extension.load( def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
name='scaled_upper_triang_masked_softmax_cuda', return cpp_extension.load(
name=name,
sources=sources,
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'--use_fast_math'] + extra_cuda_flags + cc_flag,
verbose=(args.rank == 0)
)
# ==============
# Fused softmax.
# ==============
if args.masked_softmax_fusion:
extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda']
# Upper triangular softmax.
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'], srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
build_directory=buildpath, scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
extra_cflags=['-O3',], "scaled_upper_triang_masked_softmax_cuda",
extra_cuda_cflags=['-O3', sources, extra_cuda_flags)
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + cc_flag)
def load_scaled_masked_softmax_fusion_kernel():
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
srcpath = pathlib.Path(__file__).parent.absolute() # Masked softmax.
buildpath = srcpath / 'build' sources=[srcpath / 'scaled_masked_softmax.cpp',
srcpath / 'scaled_masked_softmax_cuda.cu']
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_masked_softmax_cuda", sources, extra_cuda_flags)
create_build_dir(buildpath) # =================================
# Mixed precision fused layer norm.
# =================================
scaled_upper_triang_masked_softmax_cuda = cpp_extension.load( extra_cuda_flags = ['-maxrregcount=50']
name='scaled_masked_softmax_cuda', sources=[srcpath / 'layer_norm_cuda.cpp',
sources=[srcpath / 'scaled_masked_softmax.cpp', srcpath / 'layer_norm_cuda_kernel.cu']
srcpath / 'scaled_masked_softmax_cuda.cu'], fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
build_directory=buildpath, "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + cc_flag)
def load_fused_mix_prec_layer_norm_kernel(): def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
# Check, if CUDA11 is installed for compute capability 8.0 return raw_output, bare_metal_major, bare_metal_minor
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
create_build_dir(buildpath) def _create_build_dir(buildpath):
try:
fused_mix_prec_layer_norm_cuda = cpp_extension.load( os.mkdir(buildpath)
name='fused_mix_prec_layer_norm_cuda', except OSError:
sources=[srcpath / 'layer_norm_cuda.cpp', if not os.path.isdir(buildpath):
srcpath / 'layer_norm_cuda_kernel.cu'], print(f"Creation of the build directory {buildpath} failed")
build_directory=buildpath,
extra_cflags=['-O3'],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-maxrregcount=50',
'--use_fast_math'] + cc_flag)
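A minimal sketch of how this loader might be driven at startup. It assumes only that this file is the megatron.fused_kernels package referenced elsewhere in the repo, that a CUDA toolkit with nvcc is on the path, and that args carries the two fields load() actually reads (rank and masked_softmax_fusion):

# Sketch only: JIT-build the fused kernels once, early in initialization.
from types import SimpleNamespace
from megatron import fused_kernels

args = SimpleNamespace(rank=0, masked_softmax_fusion=True)
fused_kernels.load(args)   # compiles the extensions into fused_kernels/build/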
@@ -24,16 +24,12 @@
#include "compat.h"

namespace {

void compute_n1_n2(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    int& n1,
    int& n2) {
    int idiff = input.ndimension() - normalized_shape.size();
    n2 = 1;
    for (int i = 0; i < (int)normalized_shape.size(); ++i) {
@@ -47,11 +43,7 @@ void compute_n1_n2(
}

void check_args(
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta
    )
@@ -62,11 +54,7 @@ void check_args(
void check_args(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    int& n1,
    int& n2
    )
@@ -102,11 +90,7 @@ void check_args(
void check_args(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    int& n1,
@@ -125,60 +109,42 @@ void cuda_layer_norm(
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {

  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor output = at::empty_like(
      input, gamma.options().dtype(gamma.scalar_type()));
  at::Tensor mean = at::empty(
      {n1}, input.options().dtype(at::ScalarType::Float));
  at::Tensor invvar = at::empty_like(mean);

  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
                  normalized_shape, &gamma, &beta, epsilon);

  return {output, mean, invvar};
}


void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
@@ -186,11 +152,7 @@ void cuda_layer_norm_gradient(
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
@@ -199,62 +161,41 @@ void cuda_layer_norm_gradient(
    at::Tensor* grad_beta
    );

std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {

  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);

  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                           normalized_shape, &gamma, &beta, epsilon,
                           &grad_input, &grad_gamma, &grad_beta);

  return {grad_input, grad_gamma, grad_beta};
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine,
        "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine,
        "LayerNorm backward (CUDA)");
}
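Based only on the two bindings and the argument lists above, here is a hedged sketch of how the affine forward/backward entry points would be exercised from Python. The module name matches the one used by the kernel loader; whether it is importable this way depends on how the extension was built, so treat it as an assumption.

# Sketch only: drive forward_affine/backward_affine directly.
import torch
import fused_mix_prec_layer_norm_cuda as fused_ln

hidden = 768
x = torch.randn(4, 1024, hidden, device='cuda', dtype=torch.half)
weight = torch.ones(hidden, device='cuda', dtype=torch.half)
bias = torch.zeros(hidden, device='cuda', dtype=torch.half)

# forward_affine returns (output, mean, invvar); mean/invvar stay in fp32,
# the output takes gamma's dtype per layer_norm_affine above.
out, mean, invvar = fused_ln.forward_affine(x, (hidden,), weight, bias, 1e-5)

# backward_affine returns (grad_input, grad_gamma, grad_beta).
dout = torch.randn_like(out)
dx, dgamma, dbeta = fused_ln.backward_affine(
    dout, mean, invvar, x, (hidden,), weight, bias, 1e-5)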
@@ -285,15 +285,6 @@ struct SharedMemory <float>
    }
};

}

template<typename T, typename U, typename V> __global__
@@ -656,6 +647,9 @@ void cuComputeGradInput(
    }
  }
}


template<typename T, typename U, typename V>
void HostApplyLayerNorm(
    V* output,
@@ -671,7 +665,8 @@ void HostApplyLayerNorm(
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const dim3 threads(32,4,1);
    const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
    int nshared =
        threads.y > 1 ?
@@ -687,6 +682,7 @@ void HostApplyLayerNorm(
          gamma,beta);
}


void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
@@ -704,21 +700,21 @@ void cuda_layer_norm(
    double epsilon)
{
    using namespace at;
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
        input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
        HostApplyLayerNorm(
            output->DATA_PTR<scalar_t_out>(),
            mean->DATA_PTR<float>(),
            invvar->DATA_PTR<float>(),
            input->DATA_PTR<scalar_t_in>(),
            n1,n2,
            epsilon,
            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
            beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
      )
}


template<typename T, typename U, typename V>
void HostLayerNormGradient(
    const V* dout,
@@ -742,10 +738,12 @@ void HostLayerNormGradient(
      const int part_size = 16;
      const dim3 threads2(32,4,1);
      const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
        (threads2.x + 1);
      const int nshared2_b = threads2.x * threads2.y * sizeof(U);
      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
      at::Tensor part_grad_gamma = at::empty(
          {part_size,n2}, input->options().dtype(at::ScalarType::Float));
      at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
      cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
          dout,
@@ -770,7 +768,8 @@ void HostLayerNormGradient(
    }

    // compute grad_input
    const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
    const dim3 threads1(32,4,1);
    int nshared =
@@ -788,6 +787,7 @@ void HostLayerNormGradient(
            grad_input);
}


void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
@@ -808,22 +808,22 @@ void cuda_layer_norm_gradient(
    at::Tensor* grad_beta)
{
    using namespace at;
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), gamma->scalar_type(),
      "cuda_layer_norm_gradient_kernel",
      HostLayerNormGradient(
          dout->DATA_PTR<scalar_t_out>(),
          mean->DATA_PTR<float>(),
          invvar->DATA_PTR<float>(),
          input,
          n1,n2,
          // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
          // if gamma Tensor is NULL on input.
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
          epsilon,
          grad_input->DATA_PTR<scalar_t_in>(),
          gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
      )
}
@@ -37,8 +37,9 @@ torch::Tensor fwd(
    torch::Tensor const& mask,
    float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
@@ -52,10 +53,12 @@ torch::Tensor bwd(
  AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}
......
@@ -26,6 +26,27 @@
namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
@@ -90,13 +111,14 @@ __global__ void scaled_masked_softmax_warp_forward(
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) { // bert style
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else { // gpt2 style
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }
@@ -110,29 +132,40 @@ __global__ void scaled_masked_softmax_warp_forward(
    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i*element_count+it*WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (temp_mask[element] != 1) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -10000.0;
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }
@@ -161,15 +194,20 @@ __global__ void scaled_masked_softmax_warp_forward(
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
@@ -192,6 +230,7 @@ __global__ void scaled_masked_softmax_warp_backward(
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
@@ -207,36 +246,36 @@ __global__ void scaled_masked_softmax_warp_backward(
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    output_reg[i][it + element] = (acc_t)temp_output[element];
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                }
            }
        }
    }
@@ -257,11 +296,16 @@ __global__ void scaled_masked_softmax_warp_backward(
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
            }
        }
    }
@@ -299,8 +343,8 @@ void dispatch_scaled_masked_softmax_forward(
        constexpr int threads_per_block = 128;
        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
        dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
@@ -388,7 +432,7 @@ void dispatch_scaled_masked_softmax_backward(
        constexpr int threads_per_block = 128;
        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = batch_count/batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
......
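For orientation, here is an unfused PyTorch sketch of the math the forward kernel implements per row: scale the logits in fp32, write -10000 where the uint8 mask is 1, then softmax over the key dimension and cast back. It assumes the mask has already been expanded to the input's shape; the fused kernel additionally handles the pad_batches replication and the vectorized loads itself.

# Sketch only: unfused reference for the forward kernel's per-row math.
# inp is [batches, attn_heads, query_seq_len, key_seq_len] in fp16/bf16;
# mask is uint8 with 1 marking positions to suppress.
import torch

def scaled_masked_softmax_reference(inp, mask, scale):
    x = inp.float() * scale                  # acc_t is float in the kernel
    x = x.masked_fill(mask.bool(), -10000.0) # masked positions get -10000
    return torch.softmax(x, dim=-1).to(inp.dtype)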
@@ -19,10 +19,10 @@
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
@@ -56,16 +56,20 @@ torch::Tensor fwd_cuda(
  void* mask_ptr = static_cast<void*>(mask.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          reinterpret_cast<const uint8_t*>(mask_ptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads,
          pad_batches);
      );

  return softmax_results;
}
@@ -86,15 +90,19 @@ torch::Tensor bwd_cuda(
  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
      );

  //backward pass is completely in-place
  return output_grads;
......
@@ -33,8 +33,9 @@ torch::Tensor bwd_cuda(
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}
@@ -47,10 +48,12 @@ torch::Tensor bwd(
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}
@@ -61,7 +64,7 @@ torch::Tensor bwd(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
......
@@ -21,11 +21,47 @@
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
@@ -73,7 +109,7 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Implicit time (diagonal masking)
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
    output_t *dst,
@@ -89,10 +125,11 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;
    int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
@@ -103,22 +140,36 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if ((element_index + element) < batch_element_count) {
                        elements[i][it+element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }
@@ -140,26 +191,37 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            if (it < warp_iteration_limit) {
                elements[i][it] = std::exp((elements[i][it] - max_value[i]));
                sum[i] += elements[i][it];
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < local_seq) {

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < local_seq) {
                        out[element] = elements[i][it + element] / sum[i];
                    } else {
                        out[element] = 0;
                    }
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
            } else if (element_index < element_count) {
                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
            } else {
                break;
            }
@@ -183,6 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;
@@ -197,37 +260,41 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < batch_element_count) {
                        output_reg[i][it + element] = (acc_t)temp_output[element];
                    }
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < batch_element_count) {
                        grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                    }
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
@@ -247,11 +314,16 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
            }
        }
    }
......
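The upper-triangular variant takes no mask tensor at all; the mask is the causal structure implied by the row index (the "implicit time" masking noted in the kernel's comment). A hedged, unfused sketch of the same computation:

# Sketch only: unfused reference for the upper-triangular (causal) variant.
# inp is [attn_batches, seq_len, seq_len] in fp16/bf16.
import torch

def scaled_upper_triang_masked_softmax_reference(inp, scale):
    seq_len = inp.size(-1)
    # Positions strictly above the diagonal are masked out.
    causal = torch.triu(torch.ones(seq_len, seq_len, device=inp.device),
                        diagonal=1).bool()
    x = inp.float() * scale
    x = x.masked_fill(causal, float('-inf'))
    return torch.softmax(x, dim=-1).to(inp.dtype)  # masked slots come out 0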
@@ -19,10 +19,10 @@
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
@@ -46,15 +46,20 @@ torch::Tensor fwd_cuda(
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_forward",
      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
      );

  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
@@ -72,14 +77,18 @@ torch::Tensor bwd_cuda(
  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_backward",
      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
      );

  //backward pass is completely in-place
  return output_grads;
......
@@ -14,214 +14,78 @@
 * limitations under the License.
 */

/*This code is copied fron NVIDIA apex:
 *     https://github.com/NVIDIA/apex
 *     with minor changes. */

#include <ATen/ATen.h>
#include "compat.h"


#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
    switch(TYPE) \
    { \
        case at::ScalarType::Half: \
        { \
            using scalar_t = at::Half; \
            __VA_ARGS__; \
            break; \
        } \
        case at::ScalarType::BFloat16: \
        { \
            using scalar_t = at::BFloat16; \
            __VA_ARGS__; \
            break; \
        } \
        default: \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }


#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
    switch(TYPEIN) \
    { \
        case at::ScalarType::Float: \
        { \
            using scalar_t_in = float; \
            switch(TYPEOUT) \
            { \
                case at::ScalarType::Float: \
                { \
                    using scalar_t_out = float; \
                    __VA_ARGS__; \
                    break; \
                } \
                case at::ScalarType::Half: \
                { \
                    using scalar_t_out = at::Half; \
                    __VA_ARGS__; \
                    break; \
                } \
                case at::ScalarType::BFloat16: \
                { \
                    using scalar_t_out = at::BFloat16; \
                    __VA_ARGS__; \
                    break; \
                } \
                default: \
                    AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
            } \
            break; \
        } \
        case at::ScalarType::Half: \
        { \
            using scalar_t_in = at::Half; \
            using scalar_t_out = at::Half; \
            __VA_ARGS__; \
            break; \
        } \
        case at::ScalarType::BFloat16: \
        { \
            using scalar_t_in = at::BFloat16; \
            using scalar_t_out = at::BFloat16; \
            __VA_ARGS__; \
            break; \
        } \
        default: \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
    }
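What the new dispatch buys at the Python level is that the same fused extension can be fed either fp16 or bf16 tensors. Below is a rough parity-check sketch with several assumptions not confirmed by the hunks above: that the masked-softmax extension is importable as scaled_masked_softmax_cuda (the name used by the loader), that its binding is called forward(input, mask, scale), that the mask is laid out as [batches, 1, query_seq_len, key_seq_len], and that scaled_masked_softmax_reference is the unfused sketch given earlier.

# Sketch only: exercise both dtypes through the fused masked softmax.
import torch
import scaled_masked_softmax_cuda

b, h, sq, sk = 2, 16, 1024, 1024
mask = (torch.rand(b, 1, sq, sk, device='cuda') < 0.1).byte()

for dtype in (torch.half, torch.bfloat16):
    x = torch.randn(b, h, sq, sk, device='cuda', dtype=dtype)
    fused = scaled_masked_softmax_cuda.forward(x, mask, 1.0)
    ref = scaled_masked_softmax_reference(x, mask, 1.0)
    print(dtype, torch.allclose(fused, ref, atol=1e-2, rtol=1e-2))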