Commit d5f3875c authored by Shenggan

init commit

parent 62ed1c4a
@@ -127,3 +127,10 @@ dmypy.json
# Pyre type checker
.pyre/
# vscode
.vscode/
# setup
dist/
build/
![](/assets/fold.jpg)
# FastFold
![](https://img.shields.io/github/v/release/hpcaitech/FastFold)
[![GitHub license](https://img.shields.io/github/license/hpcaitech/FastFold.svg)](https://github.com/hpcaitech/FastFold/blob/master/LICENSE)
![](https://img.shields.io/badge/Made%20with-ColossalAI-blueviolet?style=flat)
Optimizing Protein Structure Prediction Model Training and Inference on GPU Clusters
FastFold provides a **high-performance implementation of Evoformer** with the following characteristics.
1. Excellent kernel performance on GPU platforms
2. Support for Dynamic Axial Parallelism (DAP)
   * Breaks the memory limit of a single GPU and reduces the overall training time
   * Distributed inference can significantly speed up inference and makes extremely long-sequence inference possible
3. Ease of use
   * Replace a few lines of code and you can use FastFold in your project (see the sketch after this list)
   * You don't need to care about how the parallel part is implemented
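
Below is a minimal usage sketch based on `benchmark/perf.py` in this repository; the shapes, the `dap_size` value, and the launch setup are illustrative assumptions rather than a reference example. Launch it with `torchrun` so that `WORLD_SIZE` and `LOCAL_RANK` are set:
```python
import os
import torch
import torch.distributed as dist
from fastfold.distributed import init_shadowcore
from fastfold.model import Evoformer

# One process per GPU, e.g. `torchrun --nproc_per_node=2 demo.py` (file name is hypothetical)
dist.init_process_group(backend='nccl', init_method='env://')
torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
dap_size = dist.get_world_size()  # assumption: all GPUs form one DAP group
init_shadowcore(dap_size)

block = Evoformer(d_node=256, d_pair=128).cuda().half()
# Under DAP, dim 1 of the MSA and pair representations is sharded across ranks.
node = torch.randn(1, 128 // dap_size, 256, 256, dtype=torch.float16, device='cuda')
pair = torch.randn(1, 256 // dap_size, 256, 128, dtype=torch.float16, device='cuda')
node_mask = torch.ones(1, 128, 256, dtype=torch.float16, device='cuda')
pair_mask = torch.ones(1, 256, 256, dtype=torch.float16, device='cuda')
node, pair = block(node, pair, node_mask, pair_mask)
```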
## Installation
To install from source you will need Python 3.8 or later and [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 11.1 or above.
We highly recommend creating an Anaconda or Miniconda environment and installing PyTorch with conda:
```shell
conda create -n fastfold python=3.8
conda activate fastfold
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
```
You can get the FastFold source and install it with setuptools:
```shell
git clone https://github.com/hpcaitech/FastFold
cd FastFold
python setup.py install --cuda_ext
```
## Performance Benchmark
We include a benchmark script in `./benchmark`. You can use it to benchmark Evoformer under different settings.
```shell
cd ./benchmark
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256
```
If you want to benchmark with [OpenFold](https://github.com/aqlaboratory/openfold), you need to install OpenFold first and benchmark with option `--openfold`:
```shell
torchrun --nproc_per_node=1 perf.py --msa-length 128 --res-length 256 --openfold
```
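DAP can also be exercised from the same script: the `--dap-size` flag controls how many GPUs share one model replica. For example (illustrative), to run the benchmark with 2-way DAP on two GPUs:
```shell
torchrun --nproc_per_node=2 perf.py --dap-size 2 --msa-length 128 --res-length 256
```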
## Cite us
If you use FastFold in your research publication, please cite this paper:
```
```
import argparse
import os
import torch
import torch.nn as nn
from fastfold.distributed import init_shadowcore
from fastfold.model import Evoformer
def main():
parser = argparse.ArgumentParser(description='MSA Attention Standalone Perf Benchmark')
parser.add_argument("--dap-size", default=1, type=int)
parser.add_argument('--batch-size', default=1, type=int, help='batch size')
parser.add_argument('--msa-length', default=132, type=int, help='Sequence Length of Input')
parser.add_argument('--res-length',
default=256,
type=int,
help='Start Range of Number of Sequences')
parser.add_argument('--trials', default=50, type=int, help='Number of Trials to Execute')
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
parser.add_argument('--layers',
default=12,
type=int,
help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
parser.add_argument('--cm', default=256, type=int, help='MSA hidden dimension')
parser.add_argument('--cz', default=128, type=int, help='Pair hidden dimension')
parser.add_argument('--heads', default=8, type=int, help='Number of Multihead Attention heads')
parser.add_argument('--openfold',
action='store_true',
help='Benchmark the OpenFold EvoformerBlock instead of FastFold.')
parser.add_argument('--fwd', action='store_true', help='Only execute the forward pass.')
parser.add_argument('--prof', action='store_true', help='Profile with the PyTorch profiler.')
args = parser.parse_args()
# --dap-size sets the tensor model parallel (DAP) size referenced below.
args.tensor_model_parallel_size = args.dap_size
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.global_rank = torch.distributed.get_rank()
print(
'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.global_rank, args.world_size))
init_shadowcore(args.tensor_model_parallel_size)
precision = torch.bfloat16
if args.tensor_model_parallel_size > 1:
# PyTorch's all-to-all communication currently does not support bfloat16, so fall back to float16.
precision = torch.float16
if not torch.cuda.is_available():
raise NotImplementedError('Running on CPU is not supported')
torch.manual_seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(42)
if args.openfold:
from openfold.model.evoformer import EvoformerBlock
class OpenFoldEvoformer(nn.Module):
def __init__(self, d_node, d_pair):
super(OpenFoldEvoformer, self).__init__()
self.d_node = d_node
self.d_pair = d_pair
self.c_hidden_msa_att = int(d_node / 8)
self.c_hidden_pair_att = int(d_pair / 8)
self.EvoformerBlock = EvoformerBlock(c_m=d_node,
c_z=d_pair,
c_hidden_msa_att=self.c_hidden_msa_att,
c_hidden_opm=self.c_hidden_msa_att,
c_hidden_mul=self.d_pair,
c_hidden_pair_att=self.c_hidden_pair_att,
no_heads_msa=8,
no_heads_pair=4,
transition_n=4,
msa_dropout=0.15,
pair_dropout=0.25,
inf=1e9,
eps=1e-10)
def forward(self, node, pair, node_mask, pair_mask):
node, pair = self.EvoformerBlock(node, pair, node_mask, pair_mask)
return node, pair
attn_layers = []
for idx in range(0, args.layers):
if args.openfold:
attn_layers.append(OpenFoldEvoformer(d_node=args.cm, d_pair=args.cz))
else:
attn_layers.append(Evoformer(d_node=args.cm, d_pair=args.cz))
attn_layers[idx].cuda()
attn_layers[idx].to(dtype=precision)
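# One set of CUDA events per timed trial: forward start, backward start, and backward stop.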
start_evt_fwd = []
start_evt_bwd = []
stop_evt_bwd = []
for recorded_trial in range(0, args.trials):
start_evt_fwd.append(torch.cuda.Event(enable_timing=True))
start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))
inputs_node = torch.randn(args.batch_size,
args.msa_length // args.tensor_model_parallel_size,
args.res_length,
args.cm,
dtype=precision,
device=torch.device("cuda")).requires_grad_(True)
inputs_pair = torch.randn(args.batch_size,
args.res_length // args.tensor_model_parallel_size,
args.res_length,
args.cz,
dtype=precision,
device=torch.device("cuda")).requires_grad_(True)
node_mask = torch.ones((args.batch_size, args.msa_length, args.res_length),
dtype=precision,
device=torch.device("cuda")).requires_grad_(False)
pair_mask = torch.ones((args.batch_size, args.res_length, args.res_length),
dtype=precision,
device=torch.device("cuda")).requires_grad_(False)
grads_pair = torch.randn_like(inputs_pair)  # gradient fed into the backward pass of the pair output
if args.prof:
prof = torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1,
warmup=args.warmup_trials,
active=args.trials,
repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/fastfold'),
profile_memory=False,
record_shapes=False,
with_stack=False)
prof.start()
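# Warmup iterations run first and are discarded; evt_idx is only >= 0 for recorded trials.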
for trial in range(0, args.trials + args.warmup_trials):
layer_inputs = inputs_node, inputs_pair
evt_idx = trial - args.warmup_trials
torch.distributed.barrier()
torch.cuda.synchronize()
if evt_idx >= 0:
start_evt_fwd[evt_idx].record()
for lyr_idx in range(0, args.layers):
layer_inputs = attn_layers[lyr_idx].forward(*layer_inputs, node_mask, pair_mask)
torch.cuda.synchronize()
if evt_idx >= 0:
start_evt_bwd[evt_idx].record()
if not args.fwd:
layer_inputs[1].backward(grads_pair)
if evt_idx >= 0:
stop_evt_bwd[evt_idx].record()
if args.prof:
prof.step()
if args.prof:
prof.stop()
torch.distributed.barrier()
torch.cuda.synchronize()
elapsed_time_fwd = 0.0
elapsed_time_bwd = 0.0
for evt_idx in range(0, args.trials):
elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])
elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])
print("[ MSA Attn ] Input: {:4d}, {:4d}, {:4d}, ({:4d} {:4d}) Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
args.batch_size, args.msa_length, args.res_length, \
args.cm, args.cz, \
elapsed_time_fwd / ( args.trials * args.layers ), \
elapsed_time_bwd / ( args.trials * args.layers )))
if __name__ == '__main__':
main()
from .core import (init_shadowcore, shadowcore_is_initialized, get_tensor_model_parallel_group,
get_data_parallel_group, get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank, get_data_parallel_world_size,
get_data_parallel_rank, get_tensor_model_parallel_src_rank)
from .comm import (_reduce, _split, _gather, copy, scatter, reduce, gather, col_to_row, row_to_col)
__all__ = [
'init_shadowcore', 'shadowcore_is_initialized', 'get_tensor_model_parallel_group',
'get_data_parallel_group', 'get_tensor_model_parallel_world_size',
'get_tensor_model_parallel_rank', 'get_data_parallel_world_size', 'get_data_parallel_rank',
'get_tensor_model_parallel_src_rank', '_reduce', '_split', '_gather', 'copy', 'scatter',
'reduce', 'gather', 'col_to_row', 'row_to_col'
]
from typing import Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from .core import (get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_src_rank, get_tensor_model_parallel_world_size)
from .core import ensure_divisibility
def divide(numerator, denominator):
ensure_divisibility(numerator, denominator)
return numerator // denominator
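# _reduce, _split and _gather below are the raw collectives over the tensor model parallel
# (DAP) group; the autograd.Function wrappers further down make them differentiable.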
def _reduce(tensor: Tensor) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor
dist.all_reduce(tensor,
op=dist.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
async_op=False)
return tensor
def _split(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor
split_size = divide(tensor.shape[dim], get_tensor_model_parallel_world_size())
tensor_list = torch.split(tensor, split_size, dim=dim)
output = tensor_list[get_tensor_model_parallel_rank()].contiguous()
return output
def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor
if dim == 1:
output_shape = list(tensor.shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
dist.all_gather(list(tensor_list), tensor, group=get_tensor_model_parallel_group(), async_op=False)
else:
tensor_list = [torch.ones_like(tensor) for _ in range(get_tensor_model_parallel_world_size())]
dist.all_gather(tensor_list, tensor, group=get_tensor_model_parallel_group(), async_op=False)
output = torch.cat(tensor_list, dim=dim)
return output
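# copy/scatter/reduce/gather are the autograd-aware entry points. copy is the identity in
# forward and an all-reduce in backward; reduce is its conjugate. scatter splits in forward
# and all-gathers in backward; gather is its conjugate.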
def copy(input: Tensor) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Copy.apply(input)
return input
class Copy(torch.autograd.Function):
@staticmethod
def forward(ctx: "Copy", input: Tensor) -> Tensor:
return input
@staticmethod
def backward(ctx: "Copy", grad_output: Tensor) -> Tensor:
return _reduce(grad_output)
def scatter(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Scatter.apply(input, dim)
else:
input = _split(input, dim=dim)
return input
class Scatter(torch.autograd.Function):
@staticmethod
def forward(ctx: "Scatter", input: Tensor, dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([dim]))
return _split(input, dim=dim)
@staticmethod
def backward(ctx: "Scatter", grad_output: Tensor) -> Tuple[Tensor]:
dim, = ctx.saved_tensors
return _gather(grad_output, dim=int(dim)), None
def reduce(input: Tensor) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Reduce.apply(input)
else:
input = _reduce(input)
return input
class Reduce(torch.autograd.Function):
@staticmethod
def forward(ctx: "Reduce", input: Tensor) -> Tensor:
return _reduce(input)
@staticmethod
def backward(ctx: "Reduce", grad_output: Tensor) -> Tensor:
return grad_output
def gather(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Gather.apply(input, dim)
else:
input = _gather(input, dim=dim)
return input
class Gather(torch.autograd.Function):
@staticmethod
def forward(ctx: "Gather", input: Tensor, dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([dim]))
return _gather(input, dim=dim)
@staticmethod
def backward(ctx: "Gather", grad_output: Tensor) -> Tuple[Tensor]:
dim, = ctx.saved_tensors
return _split(grad_output, dim=int(dim)), None
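# _all_to_all re-shards a tensor: input split along in_dim across the DAP group comes out
# split along out_dim instead (used to switch between row- and column-sharded layouts).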
def _all_to_all(tensor: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor
split_size = divide(tensor.shape[in_dim], get_tensor_model_parallel_world_size())
input_tensor_list = torch.split(tensor, split_size, dim=in_dim)
input_tensor_list = [tensor_.contiguous() for tensor_ in input_tensor_list]
if out_dim == 1:
output_shape = list(input_tensor_list[0].shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
output_tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
dist.all_to_all(list(output_tensor_list),
input_tensor_list,
group=get_tensor_model_parallel_group(),
async_op=False)
else:
output_tensor_list = [torch.ones_like(tensor_) for tensor_ in input_tensor_list]
dist.all_to_all(output_tensor_list,
input_tensor_list,
group=get_tensor_model_parallel_group(),
async_op=False)
output = torch.cat(output_tensor_list, dim=out_dim)
return output
def col_to_row(input_: Tensor) -> Tensor:
if torch.is_grad_enabled() and input_.requires_grad:
input_ = All_to_All.apply(input_, 1, 2)
else:
input_ = _all_to_all(input_, in_dim=1, out_dim=2)
return input_
def row_to_col(input_: Tensor) -> Tensor:
if torch.is_grad_enabled() and input_.requires_grad:
input_ = All_to_All.apply(input_, 2, 1)
else:
input_ = _all_to_all(input_, in_dim=2, out_dim=1)
return input_
class All_to_All(torch.autograd.Function):
@staticmethod
def forward(ctx: "All_to_All", input_: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([in_dim, out_dim]))
return _all_to_all(input_, in_dim=in_dim, out_dim=out_dim)
@staticmethod
def backward(ctx: "All_to_All", grad_output: Tensor) -> Tuple[Tensor]:
saved_tensors = ctx.saved_tensors[0]
return _all_to_all(grad_output, in_dim=int(saved_tensors[1]),
out_dim=int(saved_tensors[0])), None, None
from typing import Tuple
from einops import rearrange
import torch
import torch.distributed as dist
from torch import Tensor
from .core import get_tensor_model_parallel_world_size, get_tensor_model_parallel_group
from .comm import _split, divide
def _gather_async(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor, None
output_shape = list(tensor.shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
work = dist.all_gather(list(tensor_list),
tensor,
group=get_tensor_model_parallel_group(),
async_op=True)
return output, work
def gather_async(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input, work = GatherAsync.apply(input, dim)
else:
input, work = _gather_async(input, dim=dim)
return input, work
def gather_async_opp(output: Tensor, work, dim: int = -1) -> Tensor:
if work:
work.wait()
if dim == 2:
output = GatherAsyncOpp.apply(output)
return output
class GatherAsyncOpp(torch.autograd.Function):
@staticmethod
def forward(ctx: "GatherAsyncOpp", input: Tensor) -> Tensor:
mp_size = get_tensor_model_parallel_world_size()
output = rearrange(input, 'n (x h) w c -> n h (x w) c', x=mp_size)
return output
@staticmethod
def backward(ctx: "GatherAsyncOpp", grad_output: Tensor) -> Tuple[Tensor]:
mp_size = get_tensor_model_parallel_world_size()
n, h, w, c = grad_output.shape
return grad_output.resize_(n, h * mp_size, int(w / mp_size), c)
class GatherAsync(torch.autograd.Function):
@staticmethod
def forward(ctx: "GatherAsync", input: Tensor, dim: int = -1) -> Tensor:
ctx.dim = dim
return _gather_async(input, dim=dim)
@staticmethod
def backward(ctx: "GatherAsync", grad_output: Tensor, grad_work=None) -> Tuple[Tensor]:
if ctx.dim == 2:
mp_size = get_tensor_model_parallel_world_size()
n, h, w, c = grad_output.shape
grad_output.resize_(n, int(h / mp_size), w * mp_size, c)
return _split(grad_output, dim=ctx.dim), None
def _all_to_all_async(tensor: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor, None
split_size = divide(tensor.shape[in_dim], get_tensor_model_parallel_world_size())
input_tensor_list = torch.split(tensor, split_size, dim=in_dim)
input_tensor_list = [tensor_.contiguous() for tensor_ in input_tensor_list]
output_shape = list(input_tensor_list[0].shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
output_tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
work = dist.all_to_all(list(output_tensor_list),
input_tensor_list,
group=get_tensor_model_parallel_group(),
async_op=True)
return output, work
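# Handle of the backward all-to-all launched by All_to_All_Async_Opp.backward; it is waited
# on in All_to_All_Async.backward so the communication overlaps with other backward work.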
WORLD_WORK_ALL2ALL = None
class All_to_All_Async(torch.autograd.Function):
@staticmethod
def forward(ctx: "All_to_All_Async",
input_: Tensor,
in_dim: int = -1,
out_dim: int = -1) -> Tensor:
ctx.in_dim = in_dim
ctx.out_dim = out_dim
return _all_to_all_async(input_, in_dim=in_dim, out_dim=out_dim)
@staticmethod
def backward(ctx: "All_to_All_Async", grad_output: Tensor, grad_work=None) -> Tuple[Tensor]:
global WORLD_WORK_ALL2ALL
if WORLD_WORK_ALL2ALL:
WORLD_WORK_ALL2ALL.wait()
WORLD_WORK_ALL2ALL = None
if ctx.in_dim == 2:
mp_size = get_tensor_model_parallel_world_size()
grad_output = rearrange(grad_output, 'n (x h) w c -> n h (x w) c', x=mp_size)
return grad_output, None, None
class All_to_All_Async_Opp(torch.autograd.Function):
@staticmethod
def forward(ctx: "All_to_All_Async_Opp",
output: Tensor,
work,
in_dim: int = -1,
out_dim: int = -1) -> Tensor:
ctx.in_dim = in_dim
ctx.out_dim = out_dim
if work:
work.wait()
if out_dim == 2:
mp_size = get_tensor_model_parallel_world_size()
output = rearrange(output, 'n (x h) w c -> n h (x w) c', x=mp_size)
return output
@staticmethod
def backward(ctx: "All_to_All_Async_Opp", grad_output: Tensor) -> Tuple[Tensor]:
global WORLD_WORK_ALL2ALL
d_tensor, WORLD_WORK_ALL2ALL = _all_to_all_async(grad_output,
in_dim=ctx.out_dim,
out_dim=ctx.in_dim)
return d_tensor, None, None, None
import torch.distributed as dist
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# These values enable us to change the mpu sizes on the fly.
_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_TENSOR_MODEL_PARALLEL_RANK = None
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)
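# Example: world_size = 4 and tensor_model_parallel_size_ = 2 gives data-parallel groups
# [0, 2] and [1, 3], and tensor model parallel (DAP) groups [0, 1] and [2, 3].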
def init_shadowcore(tensor_model_parallel_size_=1):
assert dist.is_initialized()
world_size = dist.get_world_size()
rank = dist.get_rank()
# check dist config
ensure_divisibility(world_size, tensor_model_parallel_size_)
data_parallel_size_ = world_size // tensor_model_parallel_size_
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
for i in range(tensor_model_parallel_size_):
ranks = range(i, world_size, tensor_model_parallel_size_)
group = dist.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
'tensor model parallel group is already initialized'
# Build the model-parallel groups.
for i in range(data_parallel_size_):
ranks = range(i * tensor_model_parallel_size_, (i + 1) * tensor_model_parallel_size_)
group = dist.new_group(ranks)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
if dist.get_rank() == 0:
print('> initialize tensor model parallel with size {}'.format(tensor_model_parallel_size_))
print('> initialize data parallel with size {}'.format(data_parallel_size_))
def shadowcore_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
return False
return True
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
return _TENSOR_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _TENSOR_MODEL_PARALLEL_WORLD_SIZE
return dist.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _TENSOR_MODEL_PARALLEL_RANK
if _TENSOR_MODEL_PARALLEL_RANK is not None:
return _TENSOR_MODEL_PARALLEL_RANK
return dist.get_rank(group=get_tensor_model_parallel_group())
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return dist.get_world_size(group=get_data_parallel_group())
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return dist.get_rank(group=get_data_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = dist.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
from .msa import MSAStack
from .ops import OutProductMean
from .triangle import PairStack
from .evoformer import Evoformer
__all__ = ['MSAStack', 'OutProductMean', 'PairStack', 'Evoformer']
import torch.nn as nn
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.model import MSAStack, OutProductMean, PairStack
class Evoformer(nn.Module):
def __init__(self, d_node, d_pair):
super(Evoformer, self).__init__()
self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
self.pair_stack = PairStack(d_pair=d_pair)
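# In forward, the asynchronous all-to-all that re-shards the MSA representation (split along
# dim 1 -> split along dim 2) is launched before the pair stack and completed afterwards,
# overlapping DAP communication with the pair stack computation.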
def forward(self, node, pair, node_mask, pair_mask):
node = self.msa_stack(node, pair, node_mask)
pair = pair + self.communication(node, node_mask)
node, work = All_to_All_Async.apply(node, 1, 2)
pair = self.pair_stack(pair, pair_mask)
node = All_to_All_Async_Opp.apply(node, work, 1, 2)
return node, pair
import math
import numpy as np
import torch.nn as nn
def glorot_uniform_af(x, gain=1.0):
"""
initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
In PyTorch:
[feature_out, feature_in, n_head ...]
In Jax:
[... n_head, feature_in, feature_out]
However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like:
[feature_in, n_head, feature_out]
In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
"""
fan_in, fan_out = x.shape[-2:]
if len(x.shape) > 2:
receptive_field_size = np.prod(x.shape[:-2])
fan_in *= receptive_field_size
fan_out *= receptive_field_size
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
nn.init.uniform_(x, -dev, dev)
return x
from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
from .cuda_native.layer_norm import MixedFusedLayerNorm as LayerNorm
from .cuda_native.softmax import softmax, scale_mask_softmax, scale_mask_bias_softmax
__all__ = [
"bias_dropout_add", "bias_sigmod_ele", "bias_ele_dropout_residual", "LayerNorm", "softmax",
"scale_mask_softmax", "scale_mask_bias_softmax"
]
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#ifndef FASTFOLD_LAYER_NORM_H_
#define FASTFOLD_LAYER_NORM_H_
#include <assert.h>
#include <math_constants.h>
#include <cub/cub.cuh>
namespace fastfold {
namespace layer_norm {
constexpr int kWarpSize = 32;
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; }
};
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); }
};
template <template <typename> class ReductionOp, typename T, int thread_group_width = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
template <typename T>
__inline__ __device__ T Div(T a, T b);
template <>
__inline__ __device__ at::BFloat16 Div<at::BFloat16>(at::BFloat16 a, at::BFloat16 b) {
return a / b;
}
template <>
__inline__ __device__ float Div<float>(float a, float b) {
return __fdividef(a, b);
}
template <>
__inline__ __device__ double Div<double>(double a, double b) {
return a / b;
}
template <typename T>
__inline__ __device__ T Rsqrt(T x);
template <>
__inline__ __device__ at::BFloat16 Rsqrt<at::BFloat16>(at::BFloat16 x) {
return rsqrt(x);
}
template <>
__inline__ __device__ float Rsqrt<float>(float x) {
return rsqrt(x);
}
template <>
__inline__ __device__ double Rsqrt<double>(double x) {
return rsqrt(x);
}
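// GetNumBlocks: choose a grid of at most `waves` waves of resident blocks
// (sm_count * max_active_blocks per wave), capped by max_blocks and at least 1.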
template <class Func>
inline cudaError_t GetNumBlocks(Func func, int64_t block_size, size_t dynamic_smem_size,
int64_t max_blocks, int64_t waves, int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count;
{
cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
int max_active_blocks;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, func, block_size, dynamic_smem_size);
if (err != cudaSuccess) {
return err;
}
}
*num_blocks =
std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * max_active_blocks * waves));
return cudaSuccess;
}
template <typename T>
struct DefaultComputeType {
using type = T;
};
template <>
struct DefaultComputeType<half> {
using type = float;
};
#if CUDA_VERSION >= 11000
template <>
struct DefaultComputeType<nv_bfloat16> {
using type = float;
};
#endif // CUDA_VERSION >= 11000
template <typename T, int N>
struct GetPackType {
using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};
template <typename T, int N>
using PackType = typename GetPackType<T, N>::type;
template <typename T, int N>
union Pack {
static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");
__device__ Pack() {
// do nothing
}
PackType<T, N> storage;
T elem[N];
};
template <typename SRC, typename DST>
struct DirectLoad {
DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {}
template <int N>
__device__ void load(DST* dst, int64_t row, int64_t col) const {
Pack<SRC, N> pack;
const int64_t offset = (row * row_size + col) / N;
pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = static_cast<DST>(pack.elem[i]);
}
}
const SRC* src;
int64_t row_size;
};
template <typename SRC, typename DST>
struct DirectStore {
DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
Pack<DST, N> pack;
const int64_t offset = (row * row_size + col) / N;
#pragma unroll
for (int i = 0; i < N; ++i) {
pack.elem[i] = static_cast<DST>(src[i]);
}
*(reinterpret_cast<PackType<DST, N>*>(dst) + offset) = pack.storage;
}
DST* dst;
int64_t row_size;
};
template <typename SRC, typename DST, bool do_scale, bool do_center>
struct AffineStore {
AffineStore(DST* normalized, DST* y, int64_t row_size, const DST* gamma, const DST* beta)
: normalized(normalized), y(y), row_size(row_size), gamma(gamma), beta(beta) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
Pack<DST, N> y_pack;
Pack<DST, N> normalized_pack;
Pack<DST, N> gamma_pack;
Pack<DST, N> beta_pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t gamma_offset = col / N;
if (do_scale) {
gamma_pack.storage = *(reinterpret_cast<const PackType<DST, N>*>(gamma) + gamma_offset);
} else {
#pragma unroll
for (int i = 0; i < N; ++i) {
gamma_pack.elem[i] = 1;
}
}
if (do_center) {
beta_pack.storage = *(reinterpret_cast<const PackType<DST, N>*>(beta) + gamma_offset);
} else {
#pragma unroll
for (int i = 0; i < N; ++i) {
beta_pack.elem[i] = 0;
}
}
#pragma unroll
for (int i = 0; i < N; ++i) {
DST normalized_i = static_cast<DST>(src[i]);
if (do_scale) {
normalized_pack.elem[i] = normalized_i;
}
if (do_scale || do_center) {
y_pack.elem[i] = normalized_i * gamma_pack.elem[i] + beta_pack.elem[i];
} else {
y_pack.elem[i] = normalized_i;
}
}
*(reinterpret_cast<PackType<DST, N>*>(y) + offset) = y_pack.storage;
if (do_scale) {
*(reinterpret_cast<PackType<DST, N>*>(normalized) + offset) = normalized_pack.storage;
}
}
DST* normalized;
DST* y;
int64_t row_size;
const DST* gamma;
const DST* beta;
};
template <typename SRC, typename DST, bool do_scale>
struct ScaleLoad {
ScaleLoad(const SRC* src, const SRC* gamma, int64_t row_size)
: src(src), gamma(gamma), row_size(row_size) {}
template <int N>
__device__ void load(DST* dst, int64_t row, int64_t col) const {
Pack<SRC, N> src_pack;
Pack<SRC, N> gamma_pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t gamma_offset = col / N;
src_pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset);
if (do_scale) {
gamma_pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(gamma) + gamma_offset);
} else {
#pragma unroll
for (int i = 0; i < N; ++i) {
gamma_pack.elem[i] = static_cast<SRC>(1);
}
}
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = static_cast<DST>(src_pack.elem[i] * gamma_pack.elem[i]);
}
}
const SRC* src;
const SRC* gamma;
int64_t row_size;
};
template <typename SRC, typename DST, bool do_add>
struct AddStore {
AddStore(const DST* add_to_output, DST* dst, int64_t row_size)
: add_to_output(add_to_output), dst(dst), row_size(row_size) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
Pack<DST, N> add_to_output_pack;
Pack<DST, N> dst_pack;
const int64_t offset = (row * row_size + col) / N;
if (do_add) {
add_to_output_pack.storage =
*(reinterpret_cast<const PackType<DST, N>*>(add_to_output) + offset);
}
#pragma unroll
for (int i = 0; i < N; ++i) {
if (do_add) {
dst_pack.elem[i] = static_cast<DST>(src[i]) + add_to_output_pack.elem[i];
} else {
dst_pack.elem[i] = static_cast<DST>(src[i]);
}
}
*(reinterpret_cast<PackType<DST, N>*>(dst) + offset) = dst_pack.storage;
}
const DST* add_to_output;
DST* dst;
int64_t row_size;
};
template <typename T>
inline __device__ void WelfordCombine(T val, T* mean, T* m2, T* count) {
// Use the Welford online algorithm to compute mean and variance
// For more details you can refer to:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
*count += 1;
T delta1 = val - *mean;
*mean += Div(delta1, *count);
T delta2 = val - *mean;
*m2 += delta1 * delta2;
}
template <typename T>
inline __device__ void WelfordCombine(T b_mean, T b_m2, T b_count, T* mean, T* m2, T* count) {
if (b_count == 0) {
return;
}
T new_count = *count + b_count;
T nb_over_n = Div(b_count, new_count);
T delta = b_mean - *mean;
*mean += delta * nb_over_n;
*m2 += b_m2 + delta * delta * (*count) * nb_over_n;
*count = new_count;
}
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpReduce(T thread_mean, T thread_m2, T thread_count, T* mean,
T* m2, T* count) {
*mean = thread_mean;
*m2 = thread_m2;
*count = thread_count;
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
T b_mean = __shfl_down_sync(0xffffffff, *mean, mask);
T b_m2 = __shfl_down_sync(0xffffffff, *m2, mask);
T b_count = __shfl_down_sync(0xffffffff, *count, mask);
WelfordCombine(b_mean, b_m2, b_count, mean, m2, count);
}
}
template <typename T, int thread_group_width = kWarpSize>
__inline__ __device__ void WelfordWarpAllReduce(T thread_mean, T thread_m2, T thread_count, T* mean,
T* m2, T* count) {
WelfordWarpReduce<T, thread_group_width>(thread_mean, thread_m2, thread_count, mean, m2, count);
*mean = __shfl_sync(0xffffffff, *mean, 0, thread_group_width);
*m2 = __shfl_sync(0xffffffff, *m2, 0, thread_group_width);
*count = __shfl_sync(0xffffffff, *count, 0, thread_group_width);
}
template <typename T>
__inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T thread_count,
T* result_mean, T* result_m2, T* result_count) {
__shared__ T mean_shared[kWarpSize];
__shared__ T m2_shared[kWarpSize];
__shared__ T count_shared[kWarpSize];
__shared__ T mean_result_broadcast;
__shared__ T m2_result_broadcast;
__shared__ T count_result_broadcast;
const int lid = threadIdx.x % kWarpSize;
const int wid = threadIdx.x / kWarpSize;
T warp_mean = 0;
T warp_m2 = 0;
T warp_count = 0;
WelfordWarpReduce(thread_mean, thread_m2, thread_count, &warp_mean, &warp_m2, &warp_count);
__syncthreads();
if (lid == 0) {
mean_shared[wid] = warp_mean;
m2_shared[wid] = warp_m2;
count_shared[wid] = warp_count;
}
__syncthreads();
if (wid == 0) {
if (threadIdx.x < blockDim.x / kWarpSize) {
warp_mean = mean_shared[lid];
warp_m2 = m2_shared[lid];
warp_count = count_shared[lid];
} else {
warp_mean = static_cast<T>(0);
warp_m2 = static_cast<T>(0);
warp_count = static_cast<T>(0);
}
__syncwarp();
T block_mean = 0;
T block_m2 = 0;
T block_count = 0;
WelfordWarpReduce(warp_mean, warp_m2, warp_count, &block_mean, &block_m2, &block_count);
if (lid == 0) {
mean_result_broadcast = block_mean;
m2_result_broadcast = block_m2;
count_result_broadcast = block_count;
}
}
__syncthreads();
*result_mean = mean_result_broadcast;
*result_m2 = m2_result_broadcast;
*result_count = count_result_broadcast;
}
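// LayerNormWarpImpl: each thread group of `thread_group_width` lanes normalizes
// `rows_per_access` rows entirely in registers; per-row mean/variance are obtained with a
// Welford warp all-reduce.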
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding>
__global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
static_assert(cols_per_thread % pack_size == 0, "");
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
constexpr int num_packs = cols_per_thread / pack_size;
assert(cols <= cols_per_thread * thread_group_width);
ComputeType buf[rows_per_access][cols_per_thread];
const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int64_t num_global_thread_group = gridDim.x * blockDim.y;
const int64_t lane_id = threadIdx.x;
const int64_t step = num_global_thread_group * rows_per_access;
for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {
ComputeType thread_mean[rows_per_access];
ComputeType thread_m2[rows_per_access];
ComputeType thread_count[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
thread_mean[row_id] = 0;
thread_m2[row_id] = 0;
thread_count[row_id] = 0;
ComputeType* row_buf = buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < num_packs; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
const int pack_offset = pack_id * pack_size;
if (!padding || col < cols) {
load.template load<pack_size>(row_buf + pack_offset, row + row_id, col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
WelfordCombine(row_buf[pack_offset + i], thread_mean + row_id,
thread_m2 + row_id, thread_count + row_id);
}
} else {
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
row_buf[pack_offset + i] = 0;
}
}
}
}
ComputeType warp_mean[rows_per_access];
ComputeType warp_m2[rows_per_access];
ComputeType warp_count[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
int global_row_id = row + row_id;
ComputeType* row_buf = buf[row_id];
WelfordWarpAllReduce<ComputeType, thread_group_width>(
thread_mean[row_id], thread_m2[row_id], thread_count[row_id], warp_mean + row_id,
warp_m2 + row_id, warp_count + row_id);
ComputeType row_mean = warp_mean[row_id];
ComputeType row_variance =
max(Div(warp_m2[row_id], warp_count[row_id]), static_cast<ComputeType>(0.0));
ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
if (lane_id == 0) {
mean[global_row_id] = row_mean;
inv_variance[global_row_id] = row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_buf[i] = (row_buf[i] - row_mean) * row_inv_var;
}
#pragma unroll
for (int i = 0; i < num_packs; ++i) {
const int col = (i * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
store.template store<pack_size>(row_buf + i * pack_size, global_row_id, col);
}
}
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding>
inline cudaError_t LaunchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int block_size = 128;
constexpr int waves = 32;
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
const int64_t num_blocks =
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err =
GetNumBlocks(LayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, padding>,
block_size, 0, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread, thread_group_width,
rows_per_access, padding><<<grid_dim_x, block_dim, 0, stream>>>(
load, store, rows, cols, epsilon, mean, inv_variance);
return cudaPeekAtLastError();
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access>
inline cudaError_t DispatchLayerNormWarpImplPadding(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols == cols_per_thread * thread_group_width) {
return LaunchLayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, false>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else {
return LaunchLayerNormWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, true>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchLayerNormWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, \
kWarpSize, 1>(stream, load, store, rows, cols, \
epsilon, mean, inv_variance); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(5)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(7)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(9)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(11)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(13)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(15)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(17)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(19)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(21)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(23)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(25)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(27)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(29)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(31)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchLayerNormWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, \
kWarpSize, 1>(stream, load, store, rows, cols, \
epsilon, mean, inv_variance); \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
typename std::enable_if<pack_size == 4, cudaError_t>::type DispatchLayerNormWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 2>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} else { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, \
pack_size, thread_group_width, 1>( \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, \
kWarpSize, 1>(stream, load, store, rows, cols, \
epsilon, mean, inv_variance); \
}
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD, typename STORE, typename ComputeType>
struct DispatchLayerNormWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols % 4 == 0) {
return DispatchLayerNormWarpImplCols<LOAD, STORE, ComputeType, 4>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else if (cols % 2 == 0) {
return DispatchLayerNormWarpImplCols<LOAD, STORE, ComputeType, 2>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else {
return DispatchLayerNormWarpImplCols<LOAD, STORE, ComputeType, 1>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormWarpImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
return DispatchLayerNormWarpImplPackSize<LOAD, STORE, ComputeType>()(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
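// LayerNormBlockSMemImpl: one block walks over rows (grid-strided); each row is cached in
// dynamic shared memory so the normalization pass does not reload it from global memory.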
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto* buf = reinterpret_cast<ComputeType*>(shared_buf);
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_mean = 0;
ComputeType thread_m2 = 0;
ComputeType thread_count = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
buf[i * num_packs + pack_id] = pack[i];
WelfordCombine(pack[i], &thread_mean, &thread_m2, &thread_count);
}
}
ComputeType row_mean = 0;
ComputeType row_m2 = 0;
ComputeType row_count = 0;
WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
&row_count);
ComputeType row_variance = max(Div(row_m2, row_count), static_cast<ComputeType>(0.0));
ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
if (threadIdx.x == 0) {
mean[row] = row_mean;
inv_variance[row] = row_inv_var;
}
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack[i] = (buf[i * num_packs + pack_id] - row_mean) * row_inv_var;
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
inline cudaError_t LaunchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,
int smem, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err =
GetNumBlocks(LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>,
block_size, smem, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, smem, stream>>>(load, store, rows, cols, epsilon, mean,
inv_variance);
return cudaPeekAtLastError();
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
inline cudaError_t TryDispatchLayerNormBlockSMemImplBlockSize(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean, ComputeType* inv_variance, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(ComputeType);
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_1,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>,
block_size_conf_1, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_4,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4>,
block_size_conf_4, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_3,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3>,
block_size_conf_3, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_2,
LayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2>,
block_size_conf_2, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
}
*success = true;
return LaunchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1>(
stream, load, store, smem, rows, cols, epsilon, mean, inv_variance);
}
template <typename LOAD, typename STORE, typename ComputeType>
struct TryDispatchLayerNormBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance, bool* success) {
if (cols % 4 == 0) {
return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 4>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, success);
} else if (cols % 2 == 0) {
return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 2>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, success);
} else {
return TryDispatchLayerNormBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 1>(
stream, load, store, rows, cols, epsilon, mean, inv_variance, success);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t TryDispatchLayerNormBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance, bool* success) {
return TryDispatchLayerNormBlockSMemImplPackSize<LOAD, STORE, ComputeType>()(
stream, load, store, rows, cols, epsilon, mean, inv_variance, success);
}
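// LayerNormBlockUncachedImpl: fallback when the row does not fit in shared memory; the input
// is read from global memory twice, once for the statistics and once to normalize.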
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size>
__global__ void LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon,
ComputeType* mean, ComputeType* inv_variance) {
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_mean = 0;
ComputeType thread_m2 = 0;
ComputeType thread_count = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
WelfordCombine(pack[i], &thread_mean, &thread_m2, &thread_count);
}
}
ComputeType row_mean = 0;
ComputeType row_m2 = 0;
ComputeType row_count = 0;
WelfordBlockAllReduce<ComputeType>(thread_mean, thread_m2, thread_count, &row_mean, &row_m2,
&row_count);
ComputeType row_variance = max(Div(row_m2, row_count), static_cast<ComputeType>(0.0));
ComputeType row_inv_var = Rsqrt(row_variance + static_cast<ComputeType>(epsilon));
if (threadIdx.x == 0) {
mean[row] = row_mean;
inv_variance[row] = row_inv_var;
}
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
const int pack_offset = pack_id * pack_size;
load.template load<pack_size>(pack, row, pack_offset);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack[i] = (pack[i] - row_mean) * row_inv_var;
}
store.template store<pack_size>(pack, row, pack_offset);
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size>
inline cudaError_t LaunchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(
LayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>, block_size,
0, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, 0, stream>>>(load, store, rows, cols, epsilon, mean,
inv_variance);
return cudaPeekAtLastError();
}
template <typename LOAD, typename STORE, typename ComputeType>
struct DispatchLayerNormBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols % 4 == 0) {
return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 4>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else if (cols % 2 == 0) {
return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 2>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
} else {
return LaunchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType, 1>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
return DispatchLayerNormBlockUncachedImplPackSize<LOAD, STORE, ComputeType>()(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
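// DispatchLayerNorm: rows with cols <= 1024 use the warp implementation; wider rows try the
// shared-memory implementation and fall back to the uncached variant when shared memory does
// not allow any resident block.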
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
if (cols <= 1024) {
return DispatchLayerNormWarpImpl<LOAD, STORE, ComputeType>(stream, load, store, rows, cols,
epsilon, mean, inv_variance);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err = TryDispatchLayerNormBlockSMemImpl<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance,
&dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
return cudaSuccess;
}
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, const double epsilon, ComputeType* mean,
ComputeType* inv_variance) {
return DispatchLayerNormBlockUncachedImpl<LOAD, STORE, ComputeType>(
stream, load, store, rows, cols, epsilon, mean, inv_variance);
}
/*
LayerNormGrad dx:
normalized = (x - mean) * inv_var
sum_stats1 = sum(scaled_dy)
sum_stats2 = sum(scaled_dy * normalized)
dx = cols * dy - sum_stats1 - normalized * sum_stats2
dx *= inv_var / cols
*/
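// A host-side reference sketch of the dx formula above for a single row (illustrative only, not
// called by the kernels below; ComputeType is assumed to be float here). It spells out the two
// reduction terms that the warp/block kernels below compute with shuffles and cub reductions.
inline void LayerNormGradRowReference(const float* x, const float* scaled_dy, const float mean,
                                      const float inv_variance, const int64_t cols, float* dx) {
    float sum_stats1 = 0.f;  // sum(scaled_dy)
    float sum_stats2 = 0.f;  // sum(scaled_dy * normalized)
    for (int64_t i = 0; i < cols; ++i) {
        const float normalized = (x[i] - mean) * inv_variance;
        sum_stats1 += scaled_dy[i];
        sum_stats2 += scaled_dy[i] * normalized;
    }
    const float inv_variance_over_cols = inv_variance / static_cast<float>(cols);
    for (int64_t i = 0; i < cols; ++i) {
        const float normalized = (x[i] - mean) * inv_variance;
        dx[i] = (static_cast<float>(cols) * scaled_dy[i] - sum_stats1 -
                 normalized * sum_stats2) *
                inv_variance_over_cols;
    }
}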
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int cols_per_thread, int thread_group_width, int rows_per_access,
bool padding>
__global__ void LayerNormGradWarpImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
static_assert(cols_per_thread % pack_size == 0, "");
constexpr int pack_per_thread = cols_per_thread / pack_size;
assert(cols <= cols_per_thread * thread_group_width);
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
ComputeType normalized_buf[rows_per_access][cols_per_thread];
ComputeType dy_buf[rows_per_access][cols_per_thread];
const ComputeType one_over_cols =
static_cast<ComputeType>(1.0) / static_cast<ComputeType>(cols);
const int64_t global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int64_t num_global_thread_group = gridDim.x * blockDim.y;
const int lane_id = threadIdx.x;
const int64_t step = num_global_thread_group * rows_per_access;
for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {
ComputeType sum_stats1[rows_per_access];
ComputeType sum_stats2[rows_per_access];
ComputeType inv_variance_buf[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
const int global_row_id = row + row_id;
ComputeType mean_val = mean[global_row_id];
inv_variance_buf[row_id] = inv_variance[global_row_id];
sum_stats1[row_id] = 0;
sum_stats2[row_id] = 0;
ComputeType* row_normalized_buf = normalized_buf[row_id];
ComputeType* row_dy_buf = dy_buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
const int pack_offset = pack_id * pack_size;
if (!padding || col < cols) {
load_x.template load<pack_size>(row_normalized_buf + pack_offset, global_row_id,
col);
load_scaled_dy.template load<pack_size>(row_dy_buf + pack_offset, global_row_id,
col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
const int col_id = pack_offset + i;
                    // row_normalized_buf holds the loaded x; normalize it in place.
row_normalized_buf[col_id] =
(row_normalized_buf[col_id] - mean_val) * inv_variance_buf[row_id];
sum_stats1[row_id] += row_dy_buf[col_id];
sum_stats2[row_id] += row_dy_buf[col_id] * row_normalized_buf[col_id];
}
}
}
}
ComputeType warp_sum_stats1[rows_per_access];
ComputeType warp_sum_stats2[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
warp_sum_stats1[row_id] =
WarpAllReduce<SumOp, ComputeType, thread_group_width>(sum_stats1[row_id]);
warp_sum_stats2[row_id] =
WarpAllReduce<SumOp, ComputeType, thread_group_width>(sum_stats2[row_id]);
}
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
const int global_row_id = row + row_id;
ComputeType* row_normalized_buf = normalized_buf[row_id];
ComputeType* row_dy_buf = dy_buf[row_id];
const ComputeType inv_variance_over_cols = inv_variance_buf[row_id] * one_over_cols;
#pragma unroll
for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
for (int i = 0; i < pack_size; ++i) {
const int col_id = pack_id * pack_size + i;
row_dy_buf[col_id] =
(cols * row_dy_buf[col_id] - warp_sum_stats1[row_id] -
row_normalized_buf[col_id] * warp_sum_stats2[row_id]) *
inv_variance_over_cols;
}
store.template store<pack_size>(row_dy_buf + pack_id * pack_size, global_row_id,
col);
}
}
}
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int cols_per_thread, int thread_group_width, int rows_per_access,
bool padding>
inline cudaError_t LaunchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
constexpr int block_size = 128;
constexpr int waves = 32;
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
const int64_t num_blocks =
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(
LayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access, padding>,
block_size, 0, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, padding>
<<<grid_dim_x, block_dim, 0, stream>>>(load_x, load_scaled_dy, store, mean, inv_variance,
rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int cols_per_thread, int thread_group_width, int rows_per_access>
inline cudaError_t DispatchLayerNormGradWarpImplPadding(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols == cols_per_thread * thread_group_width) {
return LaunchLayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access,
false>(stream, load_x, load_scaled_dy, store, mean,
inv_variance, rows, cols);
} else {
return LaunchLayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access,
true>(stream, load_x, load_scaled_dy, store, mean,
inv_variance, rows, cols);
}
}
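// The DEFINE_ONE_ELIF ladders below pick the smallest (thread_group_width, cols_per_thread)
// configuration that covers `cols`, processing two rows per thread group when `rows` is even.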
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchLayerNormGradWarpImplCols(
cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, \
ComputeType, pack_size, pack_size, \
thread_group_width, 2>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} else { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, \
ComputeType, pack_size, pack_size, \
thread_group_width, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, col, kWarpSize, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(5)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(7)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(9)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(11)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(13)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(15)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(17)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(19)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(21)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(23)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(25)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(27)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(29)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(31)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchLayerNormGradWarpImplCols(
cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, \
ComputeType, pack_size, pack_size, \
thread_group_width, 2>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} else { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, \
ComputeType, pack_size, pack_size, \
thread_group_width, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchLayerNormGradWarpImplPadding<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, \
pack_size, col, kWarpSize, 1>( \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
struct DispatchLayerNormGradWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0) {
return DispatchLayerNormGradWarpImplCols<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, 2>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
} else {
return DispatchLayerNormGradWarpImplCols<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, 1>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
}
};
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormGradWarpImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
return DispatchLayerNormGradWarpImplPackSize<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>()(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
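// Block-wide backward kernel that caches the normalized input and scaled dy for one row in
// dynamic shared memory, so global memory is read only once per row.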
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int block_size>
__global__ void LayerNormGradBlockSMemImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean,
const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[];
auto* normalized_buf = reinterpret_cast<ComputeType*>(grad_shared_buf);
auto* dy_buf = normalized_buf + cols;
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
const ComputeType one_over_cols =
static_cast<ComputeType>(1.0) / static_cast<ComputeType>(cols);
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType sum_stats1 = 0;
ComputeType sum_stats2 = 0;
const ComputeType mean_val = mean[row];
const ComputeType inv_variance_val = inv_variance[row];
const ComputeType inv_variance_over_cols = inv_variance_val * one_over_cols;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
load_x.template load<pack_size>(x_pack, row, pack_id * pack_size);
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
const int buf_offset = i * num_packs + pack_id;
ComputeType normalized = (x_pack[i] - mean_val) * inv_variance_val;
normalized_buf[buf_offset] = normalized;
dy_buf[buf_offset] = dy_pack[i];
sum_stats1 += dy_pack[i];
sum_stats2 += dy_pack[i] * normalized;
}
}
const ComputeType row_sum_stats1 =
BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats1);
const ComputeType row_sum_stats2 =
BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats2);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
const int buf_offset = i * num_packs + pack_id;
pack[i] = (cols * dy_buf[buf_offset] - row_sum_stats1 -
normalized_buf[buf_offset] * row_sum_stats2) *
inv_variance_over_cols;
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int block_size>
inline cudaError_t LaunchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance, int smem,
const int64_t rows, const int64_t cols) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err =
GetNumBlocks(LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size>,
block_size, smem, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size, block_size>
<<<grid_dim_x, block_size, smem, stream>>>(load_x, load_scaled_dy, store, mean,
inv_variance, rows, cols);
return cudaPeekAtLastError();
}
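// Probes occupancy for block sizes 128/256/512/1024 and picks the largest block size whose
// occupancy matches the 128-thread baseline. Sets *success = false when the shared-memory
// footprint (2 * cols * sizeof(ComputeType) bytes) leaves no active block, so the caller can
// fall back to the uncached implementation.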
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
inline cudaError_t TryDispatchLayerNormGradBlockSMemImplBlockSize(
cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean, const ComputeType* inv_variance, const int64_t rows,
const int64_t cols, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(ComputeType) * 2;
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_1,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_1>,
block_size_conf_1, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_4,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_4>,
block_size_conf_4, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size_conf_4>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_3,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_3>,
block_size_conf_3, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size_conf_3>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_2,
LayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_2>,
block_size_conf_2, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size_conf_2>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);
}
*success = true;
return LaunchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size_conf_1>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, smem, rows, cols);
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
struct TryDispatchLayerNormGradBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols, bool* success) {
if (cols % 2 == 0) {
return TryDispatchLayerNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType, 2>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success);
} else {
return TryDispatchLayerNormGradBlockSMemImplBlockSize<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType, 1>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success);
}
}
};
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline cudaError_t TryDispatchLayerNormGradBlockSMemImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols,
bool* success) {
return TryDispatchLayerNormGradBlockSMemImplPackSize<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType>()(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols, success);
}
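// Backward fallback without shared-memory caching: x and scaled dy are re-read from global
// memory in the second pass over each row.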
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size, int block_size>
__global__ void LayerNormGradBlockUncachedImpl(LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean,
const ComputeType* inv_variance, const int64_t rows,
const int64_t cols) {
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = static_cast<int>(cols) / pack_size;
const ComputeType one_over_cols =
static_cast<ComputeType>(1.0) / static_cast<ComputeType>(cols);
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
const ComputeType mean_val = mean[row];
const ComputeType inv_variance_val = inv_variance[row];
const ComputeType inv_variance_over_cols = inv_variance_val * one_over_cols;
ComputeType sum_stats1 = 0;
ComputeType sum_stats2 = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
load_x.template load<pack_size>(x_pack, row, pack_id * pack_size);
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
sum_stats1 += dy_pack[i];
sum_stats2 += dy_pack[i] * (x_pack[i] - mean_val) * inv_variance_val;
}
}
const ComputeType row_sum_stats1 =
BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats1);
const ComputeType row_sum_stats2 =
BlockAllReduce<SumOp, ComputeType, block_size>(sum_stats2);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType x_pack[pack_size];
ComputeType dy_pack[pack_size];
load_x.template load<pack_size>(x_pack, row, pack_id * pack_size);
load_scaled_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
dy_pack[i] = (cols * dy_pack[i] - row_sum_stats1 -
(x_pack[i] - mean_val) * inv_variance_val * row_sum_stats2) *
inv_variance_over_cols;
}
store.template store<pack_size>(dy_pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType,
int pack_size>
inline cudaError_t LaunchLayerNormGradBlockUncachedImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy, STORE store,
const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err =
GetNumBlocks(LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
pack_size, block_size>,
block_size, 0, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
LayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType, pack_size,
block_size><<<grid_dim_x, block_size, 0, stream>>>(
load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
struct DispatchLayerNormGradBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0 && cols > kWarpSize) {
return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
2>(stream, load_x, load_scaled_dy, store,
mean, inv_variance, rows, cols);
} else {
return LaunchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType,
1>(stream, load_x, load_scaled_dy, store,
mean, inv_variance, rows, cols);
}
}
};
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline cudaError_t DispatchLayerNormGradBlockUncachedImpl(cudaStream_t stream, LOAD_X load_x,
LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean,
const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
return DispatchLayerNormGradBlockUncachedImplPackSize<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType>()(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
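// Backward dispatch policy, mirroring DispatchLayerNorm: warp implementation for cols <= 1024,
// otherwise the shared-memory block implementation with an uncached fallback.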
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
if (cols <= 1024) {
return DispatchLayerNormGradWarpImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err =
TryDispatchLayerNormGradBlockSMemImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols,
&dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE,
ComputeType>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
return cudaSuccess;
}
}
template <typename LOAD_X, typename LOAD_SCALED_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLayerNormGrad(cudaStream_t stream, LOAD_X load_x, LOAD_SCALED_DY load_scaled_dy,
STORE store, const ComputeType* mean, const ComputeType* inv_variance,
const int64_t rows, const int64_t cols) {
return DispatchLayerNormGradBlockUncachedImpl<LOAD_X, LOAD_SCALED_DY, STORE, ComputeType>(
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols);
}
} // namespace layer_norm
} // namespace fastfold
#endif // FASTFOLD_LAYER_NORM_H_
\ No newline at end of file
#include <torch/extension.h>
#include <cassert>
#include <vector>
#include "compat.h"
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) {
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert(input.sizes()[i + idiff] == normalized_shape[i]);
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta) {
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int& n1, int& n2) {
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape=" << normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape << ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input, normalized_shape, n1, n2);
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma,
at::Tensor beta, int& n1, int& n2) {
check_args(input, normalized_shape, n1, n2);
check_args(normalized_shape, gamma, beta);
}
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input,
int n1, int n2, at::IntArrayRef normalized_shape, at::Tensor* gamma,
at::Tensor* beta, double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm_affine(at::Tensor input, at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta, double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean = at::empty({n1}, input.options().dtype(at::ScalarType::Float));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta,
epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar,
at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape,
at::Tensor* gamma, at::Tensor* beta, double epsilon,
at::Tensor* grad_input, at::Tensor* grad_gamma,
at::Tensor* grad_beta);
std::vector<at::Tensor> layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean,
at::Tensor invvar, at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta,
epsilon, &grad_input, &grad_gamma, &grad_beta);
return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
}
\ No newline at end of file
#include <cooperative_groups.h>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#include "layer_norm.cuh"
#include "type_shim.h"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
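// Forward launcher. Input, gamma, beta and output buffers are treated as BFloat16, while the
// per-row mean/invvar statistics and the kernel's compute type are float.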
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input,
int n1, int n2, at::IntArrayRef normalized_shape, at::Tensor* gamma,
at::Tensor* beta, double epsilon) {
at::Tensor normalized = at::empty_like(*output);
fastfold::layer_norm::DirectLoad<at::BFloat16, float> load((at::BFloat16*)input->data_ptr(),
n2);
fastfold::layer_norm::AffineStore<float, at::BFloat16, true, true> store(
(at::BFloat16*)normalized.data_ptr(), (at::BFloat16*)output->data_ptr(), n2,
(at::BFloat16*)gamma->data_ptr(), (at::BFloat16*)beta->data_ptr());
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::layer_norm::DispatchLayerNorm<decltype(load), decltype(store), float>(
cuda_stream, load, store, n1, n2, epsilon, (float*)mean->data_ptr(),
(float*)invvar->data_ptr());
}
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory<float> {
__device__ float* getPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
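// cuLoadWriteStridedInputs / cuLoadAddStridedInputs stage strided dout and
// dout * (x - mean) * invvar values into the padded shared-memory warp buffers used by
// cuComputePartGradGammaBeta.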
template <typename T, typename U, typename V>
__device__ void cuLoadWriteStridedInputs(const int i1_block, const int thr_load_row_off,
const int thr_load_col_off, const int i2_off,
const int row_stride, U* warp_buf1, U* warp_buf2,
const T* input, const V* dout, const int i1_end,
const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template <typename T, typename U, typename V>
__device__ void cuLoadAddStridedInputs(const int i1_block, const int thr_load_row_off,
const int thr_load_col_off, const int i2_off,
const int row_stride, U* warp_buf1, U* warp_buf2,
const T* input, const V* dout, const int i1_end,
const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
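// Stage 1 of the two-stage gamma/beta gradient reduction (apex-style): each blockIdx.y slab of
// rows accumulates partial column sums into part_grad_gamma / part_grad_beta of shape
// [gridDim.y, n2].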
template <typename T, typename U, typename V>
__global__ void cuComputePartGradGammaBeta(const V* __restrict__ dout, const T* __restrict__ input,
const int n1, const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar, U epsilon,
U* part_grad_gamma, U* part_grad_beta) {
const int numsegs_n1 = (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
const int i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x + 1;
const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y +
// (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
i1_block += blockDim.y * blockDim.y) {
cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k * blockDim.y;
int idx1 = row1 * row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
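// Stage 2: reduce the part_size partial rows of part_grad_gamma / part_grad_beta down to the
// final grad_gamma / grad_beta vectors of length n2.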
template <typename U, typename V>
__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma, const U* part_grad_beta,
const int part_size, const int n1, const int n2,
V* grad_gamma, V* grad_beta) {
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr =
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
sum_beta += part_grad_beta_ptr[warp_offset * n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx + nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx + nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template <typename T, typename U, typename V>
void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, at::Tensor* input, int n1,
int n2, const V* gamma, const V* beta, double epsilon, T* grad_input,
V* grad_gamma, V* grad_beta) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32, 4, 1);
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma =
at::empty({part_size, n2}, input->options().dtype(at::ScalarType::Float));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon),
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>());
const dim3 threads3(32, 8, 1);
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>(), part_size, n1, n2,
grad_gamma, grad_beta);
}
}
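// Backward launcher: grad_input comes from the fused DispatchLayerNormGrad kernel (BFloat16
// buffers, float statistics), while grad_gamma / grad_beta are computed by the two-stage
// reduction in HostLayerNormGradient.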
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar,
at::Tensor* input, int n1, int n2, at::IntArrayRef normalized_shape,
at::Tensor* gamma, at::Tensor* beta, double epsilon,
at::Tensor* grad_input, at::Tensor* grad_gamma,
at::Tensor* grad_beta) {
at::Tensor add_to_output = at::empty_like(*grad_input);
fastfold::layer_norm::DirectLoad<at::BFloat16, float> load_x((at::BFloat16*)input->data_ptr(),
n2);
fastfold::layer_norm::ScaleLoad<at::BFloat16, float, true> load_scaled_dy(
(at::BFloat16*)dout->data_ptr(), (at::BFloat16*)gamma->data_ptr(), n2);
fastfold::layer_norm::AddStore<float, at::BFloat16, true> store(
(at::BFloat16*)add_to_output.data_ptr(), (at::BFloat16*)grad_input->data_ptr(), n2);
auto cuda_stream = at::cuda::getCurrentCUDAStream().stream();
fastfold::layer_norm::DispatchLayerNormGrad<decltype(load_x), decltype(load_scaled_dy),
decltype(store), float>(
cuda_stream, load_x, load_scaled_dy, store, (float*)mean->data_ptr(),
(float*)invvar->data_ptr(), n1, n2);
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(), "cuda_layer_norm_gradient_kernel",
HostLayerNormGradient(dout->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(), input, n1, n2,
                                             // Pass NULL for gamma, beta, grad_gamma and
                                             // grad_beta if the gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL, epsilon,
grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);)
}
\ No newline at end of file
#ifndef FASTFOLD_SOFTMAX_H_
#define FASTFOLD_SOFTMAX_H_
#include <assert.h>
#include <cuda.h>
#include <math_constants.h>
#include <cub/cub.cuh>
#include "ATen/ATen.h"
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000
namespace fastfold {
namespace softmax {
constexpr int kWarpSize = 32;
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const { return a + b; }
};
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); }
};
template <template <typename> class ReductionOp, typename T, int thread_group_width = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}
template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
template <typename T>
__inline__ __device__ T Inf();
template <>
__inline__ __device__ at::BFloat16 Inf<at::BFloat16>() {
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
}
template <>
__inline__ __device__ float Inf<float>() {
return CUDART_INF_F;
}
template <>
__inline__ __device__ double Inf<double>() {
return CUDART_INF;
}
template <typename T>
__inline__ __device__ T Exp(T x);
template <>
__inline__ __device__ at::BFloat16 Exp<at::BFloat16>(at::BFloat16 x) {
return exp(x);
}
template <>
__inline__ __device__ float Exp<float>(float x) {
return __expf(x);
}
template <>
__inline__ __device__ double Exp<double>(double x) {
return exp(x);
}
template <typename T>
__inline__ __device__ T Div(T a, T b);
template <>
__inline__ __device__ at::BFloat16 Div<at::BFloat16>(at::BFloat16 a, at::BFloat16 b) {
return a / b;
}
template <>
__inline__ __device__ float Div<float>(float a, float b) {
return __fdividef(a, b);
}
template <>
__inline__ __device__ double Div<double>(double a, double b) {
return a / b;
}
template <typename T>
__inline__ __device__ T Log(T x);
template <>
__inline__ __device__ at::BFloat16 Log<at::BFloat16>(at::BFloat16 x) {
return log(x);
}
template <>
__inline__ __device__ float Log<float>(float x) {
return __logf(x);
}
template <>
__inline__ __device__ double Log<double>(double x) {
return log(x);
}
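// Sizes the launch grid: at most `waves` waves of `block_size`-thread blocks across all SMs,
// capped at `max_blocks` and never below one block.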
inline cudaError_t GetNumBlocks(int64_t block_size, int64_t max_blocks, int64_t waves,
int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count;
{
cudaError_t err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
int tpm;
{
cudaError_t err = cudaDeviceGetAttribute(&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
if (err != cudaSuccess) {
return err;
}
}
*num_blocks =
std::max<int>(1, std::min<int64_t>(max_blocks, sm_count * tpm / block_size * waves));
return cudaSuccess;
}
template <typename T>
struct DefaultComputeType {
using type = T;
};
template <>
struct DefaultComputeType<half> {
using type = float;
};
#if CUDA_VERSION >= 11000
template <>
struct DefaultComputeType<nv_bfloat16> {
using type = float;
};
#endif // CUDA_VERSION >= 11000
template <typename T, int N>
struct GetPackType {
using type = typename std::aligned_storage<N * sizeof(T), N * sizeof(T)>::type;
};
template <typename T, int N>
using PackType = typename GetPackType<T, N>::type;
template <typename T, int N>
union Pack {
static_assert(sizeof(PackType<T, N>) == sizeof(T) * N, "");
__device__ Pack() {
// do nothing
}
PackType<T, N> storage;
T elem[N];
};
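// DirectLoad / DirectStore move N elements at a time through the vectorized PackType,
// converting between the storage type and the compute type on the fly.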
template <typename SRC, typename DST>
struct DirectLoad {
DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {}
template <int N>
__device__ void load(DST* dst, int64_t row, int64_t col) const {
Pack<SRC, N> pack;
const int64_t offset = (row * row_size + col) / N;
pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(src) + offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
dst[i] = static_cast<DST>(pack.elem[i]);
}
}
const SRC* src;
int64_t row_size;
};
template <typename SRC, typename DST>
struct DirectStore {
DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
Pack<DST, N> pack;
const int64_t offset = (row * row_size + col) / N;
#pragma unroll
for (int i = 0; i < N; ++i) {
pack.elem[i] = static_cast<DST>(src[i]);
}
*(reinterpret_cast<PackType<DST, N>*>(dst) + offset) = pack.storage;
}
DST* dst;
int64_t row_size;
};
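// ScaleMaskLoad / ScaleMaskStore / ScaleMaskBiasLoad fuse the attention scaling, masking and
// (optionally) bias addition into the load/store; one mask row per batch element is reused
// across heads and query rows (see mask_offset).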
template <typename SRC, typename DST>
struct ScaleMaskLoad {
ScaleMaskLoad(const SRC* src, const SRC* mask, int64_t row_size, int64_t head, SRC scale)
: src(src), mask(mask), row_size(row_size), head(head), scale(scale) {}
template <int N>
__device__ void load(DST* dst, int64_t row, int64_t col) {
softmax::Pack<SRC, N> pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t mask_offset = ((row / (head * row_size)) * row_size + col) / N;
pack.storage = *(reinterpret_cast<const softmax::PackType<SRC, N>*>(src) + offset);
softmax::Pack<SRC, N> mask_pack;
mask_pack.storage =
*(reinterpret_cast<const softmax::PackType<SRC, N>*>(mask) + mask_offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
if (mask_pack.elem[i] == 0) {
                // Fill masked (mask == 0) positions with the most negative finite bf16 so
                // they are suppressed by the softmax.
                dst[i] = static_cast<DST>(c10::BFloat16(0xFF7F, c10::BFloat16::from_bits()));
} else {
dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(scale);
}
}
}
const SRC* src;
const SRC* mask;
int64_t row_size;
int64_t head;
SRC fill;
SRC scale;
};
template <typename SRC, typename DST>
struct ScaleMaskStore {
ScaleMaskStore(DST* dst, const DST* mask, int64_t row_size, int64_t head, DST scale)
: dst(dst), mask(mask), row_size(row_size), head(head), scale(scale) {}
template <int N>
__device__ void store(const SRC* src, int64_t row, int64_t col) {
softmax::Pack<DST, N> pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t mask_offset = ((row / (head * row_size)) * row_size + col) / N;
softmax::Pack<DST, N> mask_pack;
mask_pack.storage =
*(reinterpret_cast<const softmax::PackType<DST, N>*>(mask) + mask_offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
if (mask_pack.elem[i] == 0) {
                // Masked (mask == 0) positions are written as zero; this assumes the store is
                // used on the masked-softmax gradient path, where masked inputs get no gradient.
                pack.elem[i] = static_cast<DST>(0.0f);
} else {
pack.elem[i] = static_cast<DST>(src[i]) * static_cast<DST>(scale);
}
}
*(reinterpret_cast<softmax::PackType<DST, N>*>(dst) + offset) = pack.storage;
}
DST* dst;
const DST* mask;
int64_t row_size;
int64_t head;
DST fill;
DST scale;
};
template <typename SRC, typename DST>
struct ScaleMaskBiasLoad {
ScaleMaskBiasLoad(const SRC* src, const SRC* mask, const SRC* bias, int64_t row_size,
int64_t head, SRC scale)
: src(src), mask(mask), bias(bias), row_size(row_size), head(head), scale(scale) {}
template <int N>
__device__ void load(DST* dst, int64_t row, int64_t col) {
softmax::Pack<SRC, N> pack;
const int64_t offset = (row * row_size + col) / N;
const int64_t mask_offset = ((row / (head * row_size)) * row_size + col) / N;
const int64_t bias_offset = ((row % (head * row_size)) * row_size + col) / N;
pack.storage = *(reinterpret_cast<const softmax::PackType<SRC, N>*>(src) + offset);
softmax::Pack<SRC, N> mask_pack;
softmax::Pack<SRC, N> bias_pack;
mask_pack.storage =
*(reinterpret_cast<const softmax::PackType<SRC, N>*>(mask) + mask_offset);
bias_pack.storage =
*(reinterpret_cast<const softmax::PackType<SRC, N>*>(bias) + bias_offset);
#pragma unroll
for (int i = 0; i < N; ++i) {
if (mask_pack.elem[i] == 0) {
                // Fill masked (mask == 0) positions with the most negative finite bf16 so
                // they are suppressed by the softmax.
                dst[i] = static_cast<DST>(c10::BFloat16(0xFF7F, c10::BFloat16::from_bits()));
} else {
dst[i] = static_cast<DST>(pack.elem[i]) * static_cast<DST>(scale);
dst[i] += static_cast<DST>(bias_pack.elem[i]);
}
}
}
const SRC* src;
const SRC* mask;
const SRC* bias;
int64_t row_size;
int64_t head;
SRC fill;
SRC scale;
};
enum class Algorithm {
kSoftmax = 0,
kLogSoftmax = 1,
};
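// Warp-level softmax / log-softmax: each thread group of `thread_group_width` lanes keeps an
// entire row in registers (cols_per_thread values per lane) and performs two warp all-reduces,
// first for the row max and then for the exponential sum.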
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding, Algorithm algorithm>
__global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) {
static_assert(cols_per_thread % pack_size == 0, "");
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
constexpr int num_packs = cols_per_thread / pack_size;
assert(cols <= cols_per_thread * thread_group_width);
ComputeType buf[rows_per_access][cols_per_thread];
const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int num_global_thread_group = gridDim.x * blockDim.y;
const int lane_id = threadIdx.x;
const int64_t step = num_global_thread_group * rows_per_access;
for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {
ComputeType thread_max[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
thread_max[row_id] = -Inf<ComputeType>();
ComputeType* row_buf = buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < num_packs; ++pack_id) {
const int pack_offset = pack_id * pack_size;
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
load.template load<pack_size>(row_buf + pack_offset, row + row_id, col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_max[row_id] = max(thread_max[row_id], row_buf[pack_offset + i]);
}
} else {
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
row_buf[pack_offset + i] = -Inf<ComputeType>();
}
}
}
}
ComputeType warp_max[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
warp_max[row_id] =
WarpAllReduce<MaxOp, ComputeType, thread_group_width>(thread_max[row_id]);
}
ComputeType thread_sum[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
thread_sum[row_id] = 0;
ComputeType* row_buf = buf[row_id];
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
if (algorithm == Algorithm::kSoftmax) {
row_buf[i] = Exp(row_buf[i] - warp_max[row_id]);
thread_sum[row_id] += row_buf[i];
} else if (algorithm == Algorithm::kLogSoftmax) {
row_buf[i] -= warp_max[row_id];
thread_sum[row_id] += Exp(row_buf[i]);
} else {
__trap();
}
}
}
ComputeType warp_sum[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
warp_sum[row_id] =
WarpAllReduce<SumOp, ComputeType, thread_group_width>(thread_sum[row_id]);
}
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
ComputeType* row_buf = buf[row_id];
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
if (algorithm == Algorithm::kSoftmax) {
row_buf[i] = Div(row_buf[i], warp_sum[row_id]);
} else if (algorithm == Algorithm::kLogSoftmax) {
row_buf[i] -= Log(warp_sum[row_id]);
} else {
__trap();
}
}
#pragma unroll
for (int i = 0; i < num_packs; ++i) {
const int col = (i * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
store.template store<pack_size>(row_buf + i * pack_size, row + row_id, col);
}
}
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, bool padding, Algorithm algorithm>
inline cudaError_t LaunchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 128;
constexpr int waves = 32;
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
const int64_t num_blocks =
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread, thread_group_width,
rows_per_access, padding, algorithm>
<<<grid_dim_x, block_dim, 0, stream>>>(load, store, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int cols_per_thread,
int thread_group_width, int rows_per_access, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxWarpImplPadding(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
if (cols == cols_per_thread * thread_group_width) {
return LaunchSoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, false, algorithm>(
stream, load, store, rows, cols);
} else {
return LaunchSoftmaxWarpImpl<LOAD, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, true, algorithm>(
stream, load, store, rows, cols);
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 2, algorithm>( \
stream, load, store, rows, cols); \
} else { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 1, algorithm>( \
stream, load, store, rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, \
1, algorithm>(stream, load, store, rows, cols); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(5)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(7)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(9)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(11)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(13)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(15)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(17)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(19)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(21)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(23)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(25)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(27)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(29)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(31)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxWarpImplCols(
cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 2, algorithm>( \
stream, load, store, rows, cols); \
} else { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, pack_size, \
thread_group_width, 1, algorithm>( \
stream, load, store, rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchSoftmaxWarpImplPadding<LOAD, STORE, ComputeType, pack_size, col, kWarpSize, \
1, algorithm>(stream, load, store, rows, cols); \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
struct DispatchSoftmaxWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols % 2 == 0) {
return DispatchSoftmaxWarpImplCols<LOAD, STORE, ComputeType, 2, algorithm>(
stream, load, store, rows, cols);
} else {
return DispatchSoftmaxWarpImplCols<LOAD, STORE, ComputeType, 1, algorithm>(
stream, load, store, rows, cols);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxWarpImplPackSize<LOAD, STORE, ComputeType, algorithm>()(
stream, load, store, rows, cols);
}
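// Block-level softmax that caches the whole row in dynamic shared memory between the max, sum
// and normalization passes.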
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,
Algorithm algorithm>
__global__ void SoftmaxBlockSMemImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
auto* buf = reinterpret_cast<ComputeType*>(shared_buf);
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = cols / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_max = -Inf<ComputeType>();
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
buf[i * num_packs + pack_id] = pack[i];
thread_max = max(thread_max, pack[i]);
}
}
const ComputeType row_max = BlockAllReduce<MaxOp, ComputeType, block_size>(thread_max);
ComputeType thread_sum = 0;
for (int col = tid; col < cols; col += block_size) {
if (algorithm == Algorithm::kSoftmax) {
const ComputeType exp_x = Exp(buf[col] - row_max);
buf[col] = exp_x;
thread_sum += exp_x;
} else {
const ComputeType x = buf[col] - row_max;
buf[col] = x;
thread_sum += Exp(x);
}
}
const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
pack[i] = Div(buf[i * num_packs + pack_id], row_sum);
} else if (algorithm == Algorithm::kLogSoftmax) {
pack[i] = buf[i * num_packs + pack_id] - Log(row_sum);
} else {
__trap();
}
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,
Algorithm algorithm>
inline cudaError_t LaunchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store, int smem,
const int64_t rows, const int64_t cols) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size, algorithm>
<<<grid_dim_x, block_size, smem, stream>>>(load, store, rows, cols);
return cudaPeekAtLastError();
}
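// Selects a block size for the shared-memory kernel by querying occupancy for 128/256/512/1024
// threads and taking the largest block size whose active-blocks-per-SM count matches that of 128
// threads. If even 128 threads cannot be scheduled with cols * sizeof(ComputeType) bytes of
// dynamic shared memory, *success is set to false so the caller falls back to the uncached
// kernel.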
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream, LOAD load,
STORE store, const int64_t rows,
const int64_t cols, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(ComputeType);
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_1,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1, algorithm>,
block_size_conf_1, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_4,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4, algorithm>,
block_size_conf_4, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_4,
algorithm>(stream, load, store, smem, rows, cols);
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_3,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3, algorithm>,
block_size_conf_3, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_3,
algorithm>(stream, load, store, smem, rows, cols);
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_2,
SoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2, algorithm>,
block_size_conf_2, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_2,
algorithm>(stream, load, store, smem, rows, cols);
}
*success = true;
return LaunchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, pack_size, block_size_conf_1,
algorithm>(stream, load, store, smem, rows, cols);
}
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
struct TryDispatchSoftmaxBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols, bool* success) {
if (cols % 2 == 0) {
return TryDispatchSoftmaxBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 2, algorithm>(
stream, load, store, rows, cols, success);
} else {
return TryDispatchSoftmaxBlockSMemImplBlockSize<LOAD, STORE, ComputeType, 1, algorithm>(
stream, load, store, rows, cols, success);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxBlockSMemImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols,
bool* success) {
return TryDispatchSoftmaxBlockSMemImplPackSize<LOAD, STORE, ComputeType, algorithm>()(
stream, load, store, rows, cols, success);
}
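// Fallback block-per-row kernel with no shared-memory caching: the row is re-read from global
// memory in each of the three passes (max, sum of exponentials, normalize-and-store). Slower than
// the shared-memory variant, but it handles rows of any length.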
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size,
Algorithm algorithm>
__global__ void SoftmaxBlockUncachedImpl(LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = cols / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_max = -Inf<ComputeType>();
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_max = max(thread_max, pack[i]);
}
}
const ComputeType row_max = BlockAllReduce<MaxOp, ComputeType, block_size>(thread_max);
ComputeType thread_sum = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_sum += Exp(pack[i] - row_max);
}
}
const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
load.template load<pack_size>(pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
pack[i] = Div(Exp(pack[i] - row_max), row_sum);
} else if (algorithm == Algorithm::kLogSoftmax) {
pack[i] = (pack[i] - row_max) - Log(row_sum);
} else {
__trap();
}
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD, typename STORE, typename ComputeType, int pack_size, Algorithm algorithm>
inline cudaError_t LaunchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, pack_size, block_size, algorithm>
<<<grid_dim_x, block_size, 0, stream>>>(load, store, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
struct DispatchSoftmaxBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols % 2 == 0) {
return LaunchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, 2, algorithm>(
stream, load, store, rows, cols);
} else {
return LaunchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, 1, algorithm>(
stream, load, store, rows, cols);
}
}
};
template <typename LOAD, typename STORE, typename ComputeType, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD load, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxBlockUncachedImplPackSize<LOAD, STORE, ComputeType, algorithm>()(
stream, load, store, rows, cols);
}
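// Top-level forward dispatch for non-double compute types: rows with cols <= 1024 use the warp
// implementation (row held in registers, one thread group per row); longer rows first try the
// shared-memory block implementation and fall back to the uncached block implementation when the
// row does not fit. The double overload below always uses the uncached kernel.
//
// Usage sketch (illustrative only; `MyLoad`/`MyStore` are hypothetical functors -- any types
// providing `template <int N> void load(ComputeType* dst, int64_t row, int64_t col)` and
// `template <int N> void store(const ComputeType* src, int64_t row, int64_t col)` work here):
//
//   MyLoad load(src_ptr, cols);
//   MyStore store(dst_ptr, cols);
//   cudaError_t err =
//       DispatchSoftmax<MyLoad, MyStore, float>(stream, load, store, rows, cols);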
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols <= 1024) {
return DispatchSoftmaxWarpImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err =
TryDispatchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols);
}
return cudaSuccess;
}
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kSoftmax>(
stream, load, store, rows, cols);
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
if (cols <= 1024) {
return DispatchSoftmaxWarpImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(
stream, load, store, rows, cols);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err =
TryDispatchSoftmaxBlockSMemImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(
stream, load, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType,
Algorithm::kLogSoftmax>(stream, load, store,
rows, cols);
}
return cudaSuccess;
}
}
template <typename LOAD, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows,
const int64_t cols) {
return DispatchSoftmaxBlockUncachedImpl<LOAD, STORE, ComputeType, Algorithm::kLogSoftmax>(
stream, load, store, rows, cols);
}
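// Backward warp kernel. With y the forward output and dy the upstream gradient, each thread group
// reduces s = sum(y * dy) (softmax) or s = sum(dy) (log-softmax) over its row and writes
//   dx = (dy - s) * y        for Algorithm::kSoftmax
//   dx = dy - exp(y) * s     for Algorithm::kLogSoftmax
// The row is held in registers, so the dispatcher below selects this path only for cols <= 1024.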
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int cols_per_thread, int thread_group_width, int rows_per_access, bool padding,
Algorithm algorithm>
__global__ void SoftmaxGradWarpImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,
const int64_t cols) {
static_assert(cols_per_thread % pack_size == 0, "");
constexpr int pack_per_thread = cols_per_thread / pack_size;
assert(cols <= cols_per_thread * thread_group_width);
static_assert(thread_group_width <= kWarpSize, "");
static_assert(kWarpSize % thread_group_width == 0, "");
ComputeType y_buf[rows_per_access][cols_per_thread];
ComputeType dy_buf[rows_per_access][cols_per_thread];
const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;
const int num_global_thread_group = gridDim.x * blockDim.y;
const int lane_id = threadIdx.x;
const int64_t step = num_global_thread_group * rows_per_access;
for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {
ComputeType thread_sum[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
thread_sum[row_id] = 0;
ComputeType* row_y_buf = y_buf[row_id];
ComputeType* row_dy_buf = dy_buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {
const int pack_offset = pack_id * pack_size;
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
load_y.template load<pack_size>(row_y_buf + pack_offset, row + row_id, col);
load_dy.template load<pack_size>(row_dy_buf + pack_offset, row + row_id, col);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
thread_sum[row_id] +=
row_y_buf[pack_offset + i] * row_dy_buf[pack_offset + i];
} else if (algorithm == Algorithm::kLogSoftmax) {
thread_sum[row_id] += row_dy_buf[pack_offset + i];
} else {
__trap();
}
}
}
}
}
ComputeType warp_sum[rows_per_access];
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
warp_sum[row_id] =
WarpAllReduce<SumOp, ComputeType, thread_group_width>(thread_sum[row_id]);
}
#pragma unroll
for (int row_id = 0; row_id < rows_per_access; ++row_id) {
ComputeType* row_y_buf = y_buf[row_id];
ComputeType* row_dy_buf = dy_buf[row_id];
#pragma unroll
for (int pack_id = 0; pack_id < pack_per_thread; ++pack_id) {
const int pack_offset = pack_id * pack_size;
const int col = (pack_id * thread_group_width + lane_id) * pack_size;
if (!padding || col < cols) {
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
row_dy_buf[pack_offset + i] =
(row_dy_buf[pack_offset + i] - warp_sum[row_id]) *
row_y_buf[pack_offset + i];
} else if (algorithm == Algorithm::kLogSoftmax) {
row_dy_buf[pack_offset + i] -=
Exp(row_y_buf[pack_offset + i]) * warp_sum[row_id];
} else {
__trap();
}
}
store.template store<pack_size>(row_dy_buf + pack_offset, row + row_id, col);
}
}
}
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int cols_per_thread, int thread_group_width, int rows_per_access, bool padding,
Algorithm algorithm>
inline cudaError_t LaunchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy,
STORE store, const int64_t rows, const int64_t cols) {
constexpr int block_size = 128;
constexpr int waves = 32;
static_assert(block_size % thread_group_width == 0, "");
constexpr int thread_groups_per_block = block_size / thread_group_width;
dim3 block_dim(thread_group_width, thread_groups_per_block);
const int64_t num_blocks =
(rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, cols_per_thread,
thread_group_width, rows_per_access, padding, algorithm>
<<<grid_dim_x, block_dim, 0, stream>>>(load_y, load_dy, store, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int cols_per_thread, int thread_group_width, int rows_per_access, Algorithm algorithm>
inline cudaError_t DispatchSoftmaxGradWarpImplPadding(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols == cols_per_thread * thread_group_width) {
return LaunchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access,
false, algorithm>(stream, load_y, load_dy, store, rows,
cols);
} else {
return LaunchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
cols_per_thread, thread_group_width, rows_per_access, true,
algorithm>(stream, load_y, load_dy, store, rows, cols);
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
typename std::enable_if<pack_size == 1, cudaError_t>::type DispatchSoftmaxGradWarpImplCols(
cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,
const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 2, \
algorithm>(stream, load_y, load_dy, store, \
rows, cols); \
} else { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 1, \
algorithm>(stream, load_y, load_dy, store, \
rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, \
col, kWarpSize, 1, algorithm>( \
stream, load_y, load_dy, store, rows, cols); \
}
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(3)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(5)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(7)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(9)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(11)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(13)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(15)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(17)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(19)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(21)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(23)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(25)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(27)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(29)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(31)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
typename std::enable_if<pack_size == 2, cudaError_t>::type DispatchSoftmaxGradWarpImplCols(
cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, const int64_t rows,
const int64_t cols) {
if (cols <= 0) {
return cudaErrorInvalidValue;
}
#define DEFINE_ONE_ELIF(thread_group_width) \
else if (cols <= (thread_group_width)*pack_size) { \
if (rows % 2 == 0) { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 2, \
algorithm>(stream, load_y, load_dy, store, \
rows, cols); \
} else { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, \
pack_size, pack_size, thread_group_width, 1, \
algorithm>(stream, load_y, load_dy, store, \
rows, cols); \
} \
}
DEFINE_ONE_ELIF(1)
DEFINE_ONE_ELIF(2)
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col) \
else if (cols <= (col)*kWarpSize) { \
return DispatchSoftmaxGradWarpImplPadding<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, \
col, kWarpSize, 1, algorithm>( \
stream, load_y, load_dy, store, rows, cols); \
}
DEFINE_ONE_ELIF(4)
DEFINE_ONE_ELIF(6)
DEFINE_ONE_ELIF(8)
DEFINE_ONE_ELIF(10)
DEFINE_ONE_ELIF(12)
DEFINE_ONE_ELIF(14)
DEFINE_ONE_ELIF(16)
DEFINE_ONE_ELIF(18)
DEFINE_ONE_ELIF(20)
DEFINE_ONE_ELIF(22)
DEFINE_ONE_ELIF(24)
DEFINE_ONE_ELIF(26)
DEFINE_ONE_ELIF(28)
DEFINE_ONE_ELIF(30)
DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
else {
return cudaErrorInvalidValue;
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
struct DispatchSoftmaxGradWarpImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0) {
return DispatchSoftmaxGradWarpImplCols<LOAD_Y, LOAD_DY, STORE, ComputeType, 2,
algorithm>(stream, load_y, load_dy, store, rows,
cols);
} else {
return DispatchSoftmaxGradWarpImplCols<LOAD_Y, LOAD_DY, STORE, ComputeType, 1,
algorithm>(stream, load_y, load_dy, store, rows,
cols);
}
}
};
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
inline cudaError_t DispatchSoftmaxGradWarpImpl(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy,
STORE store, const int64_t rows,
const int64_t cols) {
return DispatchSoftmaxGradWarpImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType, algorithm>()(
stream, load_y, load_dy, store, rows, cols);
}
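// Block-per-row backward kernel that caches both y and dy in dynamic shared memory (hence the
// 2 * cols * sizeof(ComputeType) smem requirement in the block-size dispatcher below). The row
// reduction is computed once with a block-wide reduce and dx is written from the cached values.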
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int block_size, Algorithm algorithm>
__global__ void SoftmaxGradBlockSMemImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
extern __shared__ __align__(sizeof(double)) unsigned char grad_shared_buf[];
auto* y_buf = reinterpret_cast<ComputeType*>(grad_shared_buf);
auto* dy_buf = y_buf + cols;
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = cols / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_sum = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType y_pack[pack_size];
ComputeType dy_pack[pack_size];
load_y.template load<pack_size>(y_pack, row, pack_id * pack_size);
load_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
y_buf[i * num_packs + pack_id] = y_pack[i];
dy_buf[i * num_packs + pack_id] = dy_pack[i];
if (algorithm == Algorithm::kSoftmax) {
thread_sum += y_pack[i] * dy_pack[i];
} else if (algorithm == Algorithm::kLogSoftmax) {
thread_sum += dy_pack[i];
} else {
__trap();
}
}
}
const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
pack[i] = (dy_buf[i * num_packs + pack_id] - row_sum) *
y_buf[i * num_packs + pack_id];
} else if (algorithm == Algorithm::kLogSoftmax) {
pack[i] = dy_buf[i * num_packs + pack_id] -
Exp(y_buf[i * num_packs + pack_id]) * row_sum;
} else {
__trap();
}
}
store.template store<pack_size>(pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int block_size, Algorithm algorithm>
inline cudaError_t LaunchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store, int smem,
const int64_t rows, const int64_t cols) {
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size, algorithm>
<<<grid_dim_x, block_size, smem, stream>>>(load_y, load_dy, store, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows,
const int64_t cols, bool* success) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(ComputeType) * 2;
int max_active_blocks_conf_1;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_1,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_1, algorithm>,
block_size_conf_1, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_1 <= 0) {
*success = false;
return cudaSuccess;
}
int max_active_blocks_conf_4;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_4,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_4, algorithm>,
block_size_conf_4, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_4 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_4, algorithm>(
stream, load_y, load_dy, store, smem, rows, cols);
}
int max_active_blocks_conf_3;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_3,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_3, algorithm>,
block_size_conf_3, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_3 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_3, algorithm>(
stream, load_y, load_dy, store, smem, rows, cols);
}
int max_active_blocks_conf_2;
{
cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_2,
SoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_2, algorithm>,
block_size_conf_2, smem);
if (err != cudaSuccess) {
return err;
}
}
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
*success = true;
return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_2, algorithm>(
stream, load_y, load_dy, store, smem, rows, cols);
}
*success = true;
return LaunchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size,
block_size_conf_1, algorithm>(stream, load_y, load_dy,
store, smem, rows, cols);
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
struct TryDispatchSoftmaxGradBlockSMemImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols, bool* success) {
if (cols % 2 == 0) {
return TryDispatchSoftmaxGradBlockSMemImplBlockSize<LOAD_Y, LOAD_DY, STORE, ComputeType,
2, algorithm>(
stream, load_y, load_dy, store, rows, cols, success);
} else {
return TryDispatchSoftmaxGradBlockSMemImplBlockSize<LOAD_Y, LOAD_DY, STORE, ComputeType,
1, algorithm>(
stream, load_y, load_dy, store, rows, cols, success);
}
}
};
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
inline cudaError_t TryDispatchSoftmaxGradBlockSMemImpl(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols,
bool* success) {
return TryDispatchSoftmaxGradBlockSMemImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType,
algorithm>()(stream, load_y, load_dy, store,
rows, cols, success);
}
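// Fallback backward kernel with no caching: y and dy are re-read from global memory in both the
// reduction pass and the store pass, so it works for rows of any length.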
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
int block_size, Algorithm algorithm>
__global__ void SoftmaxGradBlockUncachedImpl(LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
const int tid = threadIdx.x;
assert(cols % pack_size == 0);
const int num_packs = cols / pack_size;
for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
ComputeType thread_sum = 0;
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType y_pack[pack_size];
ComputeType dy_pack[pack_size];
load_y.template load<pack_size>(y_pack, row, pack_id * pack_size);
load_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
thread_sum += y_pack[i] * dy_pack[i];
} else if (algorithm == Algorithm::kLogSoftmax) {
thread_sum += dy_pack[i];
} else {
__trap();
}
}
}
const ComputeType row_sum = BlockAllReduce<SumOp, ComputeType, block_size>(thread_sum);
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
ComputeType y_pack[pack_size];
ComputeType dy_pack[pack_size];
load_y.template load<pack_size>(y_pack, row, pack_id * pack_size);
load_dy.template load<pack_size>(dy_pack, row, pack_id * pack_size);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
if (algorithm == Algorithm::kSoftmax) {
dy_pack[i] = (dy_pack[i] - row_sum) * y_pack[i];
} else if (algorithm == Algorithm::kLogSoftmax) {
dy_pack[i] -= Exp(y_pack[i]) * row_sum;
} else {
__trap();
}
}
store.template store<pack_size>(dy_pack, row, pack_id * pack_size);
}
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType, int pack_size,
Algorithm algorithm>
inline cudaError_t LaunchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
constexpr int block_size = 1024;
constexpr int waves = 32;
int grid_dim_x;
{
cudaError_t err = GetNumBlocks(block_size, rows, waves, &grid_dim_x);
if (err != cudaSuccess) {
return err;
}
}
SoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, pack_size, block_size,
algorithm>
<<<grid_dim_x, block_size, 0, stream>>>(load_y, load_dy, store, rows, cols);
return cudaPeekAtLastError();
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
struct DispatchSoftmaxGradBlockUncachedImplPackSize {
cudaError_t operator()(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols % 2 == 0 && cols > kWarpSize) {
return LaunchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, 2,
algorithm>(stream, load_y, load_dy, store,
rows, cols);
} else {
return LaunchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType, 1,
algorithm>(stream, load_y, load_dy, store,
rows, cols);
}
}
};
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType,
Algorithm algorithm>
inline cudaError_t DispatchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOAD_Y load_y,
LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxGradBlockUncachedImplPackSize<LOAD_Y, LOAD_DY, STORE, ComputeType,
algorithm>()(stream, load_y, load_dy, store,
rows, cols);
}
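// Top-level backward dispatch, mirroring DispatchSoftmax above: the warp implementation for
// cols <= 1024, then the shared-memory block implementation, then the uncached fallback; the
// double overload always uses the uncached kernel. load_y reads the forward output, load_dy reads
// the upstream gradient, and store writes the input gradient.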
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols <= 1024) {
return DispatchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kSoftmax>(stream, load_y, load_dy, store,
rows, cols);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err = TryDispatchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE,
ComputeType, Algorithm::kSoftmax>(
stream, load_y, load_dy, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kSoftmax>(
stream, load_y, load_dy, store, rows, cols);
}
return cudaSuccess;
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kSoftmax>(stream, load_y, load_dy, store,
rows, cols);
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<!std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
if (cols <= 1024) {
return DispatchSoftmaxGradWarpImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kLogSoftmax>(stream, load_y, load_dy, store,
rows, cols);
} else {
bool dispatch_smem_impl_success;
{
cudaError_t err =
TryDispatchSoftmaxGradBlockSMemImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kLogSoftmax>(
stream, load_y, load_dy, store, rows, cols, &dispatch_smem_impl_success);
if (err != cudaSuccess) {
return err;
}
}
if (!dispatch_smem_impl_success) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kLogSoftmax>(
stream, load_y, load_dy, store, rows, cols);
}
return cudaSuccess;
}
}
template <typename LOAD_Y, typename LOAD_DY, typename STORE, typename ComputeType>
inline typename std::enable_if<std::is_same<ComputeType, double>::value, cudaError_t>::type
DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store,
const int64_t rows, const int64_t cols) {
return DispatchSoftmaxGradBlockUncachedImpl<LOAD_Y, LOAD_DY, STORE, ComputeType,
Algorithm::kLogSoftmax>(stream, load_y, load_dy,
store, rows, cols);
}
} // namespace softmax
} // namespace fastfold
#endif // FASTFOLD_SOFTMAX_H_
#include <torch/extension.h>
at::Tensor softmax(at::Tensor input, int rows, int cols);
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor input, int rows, int cols);
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
float scale);
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
int rows, int cols, float scale);
at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
int rows, int cols, float scale);
at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input,
at::Tensor mask, at::Tensor bias, int rows,
int cols, float scale);
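// pybind11 bindings for the CUDA softmax kernels declared above. TORCH_EXTENSION_NAME is injected
// by the PyTorch extension build, so the Python-visible module name depends on how the extension
// is compiled; each function is exposed under the name passed as the first argument to m.def().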
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &softmax, "Softmax forward (CUDA)");
m.def("backward_affine", &softmax_gradient, "Softmax backward (CUDA)");
m.def("fused_scale_mask_softmax_forward", &fused_scale_mask_softmax_forward,
"Softmax forward (CUDA)");
m.def("fused_scale_mask_softmax_backward", &fused_scale_mask_softmax_backward,
"Softmax forward (CUDA)");
m.def("fused_scale_mask_bias_softmax_forward", &fused_scale_mask_bias_softmax_forward,
"Softmax forward (CUDA)");
m.def("fused_scale_mask_bias_softmax_backward", &fused_scale_mask_bias_softmax_backward,
"Softmax forward (CUDA)");
}