"web/vscode:/vscode.git/clone" did not exist on "b016e2769f0a16fcba21c020023413cad68f704b"
Commit 1fcbe6f0 authored by Tri Dao's avatar Tri Dao
Browse files

First release

parents
Alpha release of FlashAttention.
To compile:
```
cd csrc/stream_attn
python setup.py install
```
Interface: `streaming_attention.py`
Contact: `trid@stanford.edu`
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
ctx.first_axis_dim = input.shape[0]
assert input.ndim == 2
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(input, 0, repeat(indices, 'z -> z d', d=input.shape[1]))
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
grad_input = torch.zeros([ctx.first_axis_dim, *grad_output.shape[1:]],
device=grad_output.device, dtype=grad_output.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output)
return grad_input, None
index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim == 2
output = torch.zeros(first_axis_dim, values.shape[1], device=values.device,
dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None
index_put_first_axis = IndexPutFirstAxis.apply
def unpad_input(hidden_states, attention_mask):
"""
Arguments:
hidden_states: (batch, seqlen, dim)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
Return:
hidden_states: (total_nnz, dim), where total_nnz = number of tokens in selected in attention_mask.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (index_first_axis(rearrange(hidden_states, 'b s d -> (b s) d'), indices), indices,
cu_seqlens, max_seqlen_in_batch)
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, dim), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz)
Return:
hidden_states: (batch, seqlen, dim)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, '(b s) d -> b s d', b=batch)
Our implementation uses Apex's
[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code
as a starting point.
We thank [Young-jun Ko](https://yjk21.github.io/) for the in-depth explanation of his FMHA implementation
and for his thoughtful answers to our questions about CUDA.
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "fmha.h"
void set_params(Fused_multihead_attention_fprop_params &params,
// sizes
const size_t b,
const size_t s,
const size_t h,
const size_t d,
// device pointers
void *qkv_packed_d,
void *cu_seqlens_d,
void *o_packed_d,
void *o_tmp_d,
void *do_packed_d,
void *s_d,
void *softmax_lse_d,
void *dsoftmax_sum_d,
float p_dropout,
float softmax_scale,
bool is_causal) {
Data_type acc_type = DATA_TYPE_FP32;
Data_type data_type = DATA_TYPE_FP16;
// Reset the parameters
memset(&params, 0, sizeof(params));
// Set the pointers and strides.
params.qkv_ptr = qkv_packed_d;
params.qkv_stride_in_elts = h * 3 * d;
params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);
params.o_ptr = o_packed_d;
params.o_stride_in_elts = h * d;
params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);
params.do_ptr = do_packed_d;
params.o_tmp_ptr = o_tmp_d;
params.cu_seqlens = static_cast<int *>(cu_seqlens_d);
// S = softmax(P)
params.s_ptr = s_d;
params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);
// Softmax sum
params.softmax_lse_ptr = softmax_lse_d;
params.dsoftmax_sum = dsoftmax_sum_d;
// Set the dimensions.
params.b = b;
params.h = h;
params.s = s;
params.d = d;
// Set the different scale values.
// const float scale_bmm1 = 1.f / sqrtf(d);
const float scale_bmm1 = softmax_scale;
constexpr float scale_softmax = 1.f;
constexpr float scale_bmm2 = 1.f;
params.scale_bmm1f = scale_bmm1;
set_alpha(params.scale_bmm1, scale_bmm1, data_type);
set_alpha(params.scale_softmax, scale_softmax, acc_type);
set_alpha(params.scale_bmm2, scale_bmm2, data_type);
// Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout;
// Convert p from float to int so we don't have to convert the random uint to float to compare.
// [Minor] We want to round down since when we do the comparison we use <= instead <
params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
params.rp_dropout = 1.f / params.p_dropout;
TORCH_CHECK(p_dropout < 1.f);
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
params.is_causal = is_causal;
}
std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
auto stream = at::cuda::getCurrentCUDAStream().stream();
bool is_dropout = p_dropout > 0.0;
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
// int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
int base_N = head_size == 128 ? 128 : 256;
// int base_N = 256;
int seq_len = 512;
if( max_seq_len <= 128 ) {
seq_len = 128;
} else if( max_seq_len <= 256 ) {
seq_len = 256;
} else {
seq_len = ((max_seq_len + base_N - 1) / base_N) * base_N;
}
bool loop = seq_len > base_N;
auto opts = qkv.options();
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
at::Tensor o_tmp;
if (loop) {
o_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat));
}
auto softmax_lse = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
// auto softmax_lse = torch::full({batch_size, num_heads, seq_len}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
at::Tensor s;
if (return_softmax) {
s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
// s = torch::ones({ batch_size, num_heads, seq_len, seq_len }, opts) * 10000.0;
}
if( zero_tensors ) {
ctx.zero_();
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
if (loop) { o_tmp.zero_(); }
if (return_softmax) {s.zero_();}
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
set_params(launch_params.params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
ctx.data_ptr(),
loop ? o_tmp.data_ptr() : nullptr,
nullptr,
return_softmax ? s.data_ptr() : nullptr,
softmax_lse.data_ptr(),
nullptr,
p_dropout,
softmax_scale,
is_causal);
run_fmha_fp16_sm80(launch_params, /*configure=*/ true);
// number of times random will be generated per thread, to offset philox counter in thc random
// state
int64_t counter_offset = launch_params.elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_dropout ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
}
run_fmha_fp16_sm80(launch_params, /*configure=*/false);
std::vector<at::Tensor> result = {ctx, softmax_lse};
if (return_softmax) {result.push_back(s);}
return result;
}
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total x num_heads x head_size
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop
const float softmax_scale,
const int max_seq_len, // max sequence length to choose the kernel
const bool zero_tensors,
const bool is_causal,
c10::optional<at::Generator> gen_
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
auto launch = &run_fmha_dgrad_fp16_sm80;
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(dout.dtype() == torch::kFloat16);
TORCH_CHECK(softmax.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda());
TORCH_CHECK(cu_seqlens.is_cuda());
TORCH_CHECK(qkv.is_contiguous());
TORCH_CHECK(cu_seqlens.is_contiguous());
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
// int base_N = head_size == 16 ? 512 : (head_size == 128 ? 128 : 256);
int base_N = head_size == 128 ? 128 : 256;
int seq_len = 512;
if( max_seq_len <= 128 ) {
seq_len = 128;
} else if( max_seq_len <= 256 ) {
seq_len = 256;
} else {
seq_len = ((max_seq_len + base_N - 1) / base_N) * base_N;
}
bool loop = seq_len > base_N;
auto dqkv = torch::empty_like(qkv);
auto opts = qkv.options();
// auto softmax_lse =
// torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
auto softmax_d = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
// softmax.zero_();
// torch::nn::init::ones_(softmax);
// torch::nn::init::ones_(dqkv);
at::Tensor dq_tmp;
if (loop) {
dq_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat));
}
if( zero_tensors ) {
dqkv.zero_();
softmax_d.zero_();
if (loop) { dq_tmp.zero_(); }
}
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
out.data_ptr(),
loop ? dq_tmp.data_ptr() : nullptr,
dout.data_ptr(),
softmax.data_ptr(), // softmax gets overwritten by dP!
softmax_lse.data_ptr(),
softmax_d.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// We're gonna reset the rng state in Python after this kernel, so the counter offset
// here doesn't matter at all. We just choose an arbitrary number;
int64_t counter_offset = 4;
if( is_dropout ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
Data_type acc_type = DATA_TYPE_FP32;
params.dqkv_ptr = dqkv.data_ptr();
launch(params, stream);
return { dqkv, softmax, softmax_d };
// std::vector<at::Tensor> result = {dqkv, softmax, softmax_d};
// if (loop) {
// result.push_back(dq_tmp);
// }
// return result;
}
std::vector<at::Tensor>
mha_fwd_block(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const at::Tensor &blockmask, // (seqlen / 256, seqlen / 16)
const float p_dropout,
const int max_seq_len,
const float softmax_scale,
const bool is_causal,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
auto stream = at::cuda::getCurrentCUDAStream().stream();
bool is_dropout = p_dropout > 0.0;
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
bool loop = false;
int seq_len = 256;
if( max_seq_len > 256 ) {
seq_len = ((max_seq_len + 256 - 1) / 256) * 256;
loop = true;
}
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(blockmask.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(blockmask.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
TORCH_CHECK(blockmask.dim() == 2);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64);
auto opts = qkv.options();
auto ctx = torch::zeros({ total, num_heads, head_size }, opts);
at::Tensor o_tmp;
if (loop) {
// o_tmp = torch::zeros({total, num_heads, head_size}, opts.dtype(at::kFloat));
o_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat));
}
// auto softmax_lse = torch::full({batch_size, num_heads, seq_len}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
auto softmax_lse = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
at::Tensor s;
if (return_softmax) {
s = torch::zeros({ batch_size, num_heads, seq_len, seq_len }, opts);
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
set_params(launch_params.params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
ctx.data_ptr(),
loop ? o_tmp.data_ptr() : nullptr,
nullptr,
return_softmax ? s.data_ptr() : nullptr,
softmax_lse.data_ptr(),
nullptr,
p_dropout,
softmax_scale,
is_causal);
launch_params.params.blockmask = static_cast<int *>(blockmask.data_ptr());
run_fmha_block_fp16_sm80(launch_params, /*configure=*/ true);
// number of times random will be generated per thread, to offset philox counter in thc random
// state
int64_t counter_offset = launch_params.elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_dropout ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
}
run_fmha_block_fp16_sm80(launch_params, /*configure=*/false);
std::vector<at::Tensor> result = {ctx, softmax_lse};
if (return_softmax) {result.push_back(s);}
return result;
}
std::vector<at::Tensor>
mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total x num_heads x head_size
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
const at::Tensor &cu_seqlens, // b+1
const at::Tensor &blockmask, // (seqlen / 256, seqlen / 16)
const float p_dropout, // probability to drop
const float softmax_scale,
const int max_seq_len, // max sequence length to choose the kernel
const bool is_causal,
c10::optional<at::Generator> gen_
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
bool loop = false;
int seq_len = 256;
auto launch = &run_fmha_block_dgrad_fp16_sm80;
if (max_seq_len > 256) {
seq_len = ((max_seq_len + 256 - 1) / 256) * 256;
loop = true;
}
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(dout.dtype() == torch::kFloat16);
TORCH_CHECK(softmax.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(blockmask.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda());
TORCH_CHECK(cu_seqlens.is_cuda());
TORCH_CHECK(blockmask.is_cuda());
TORCH_CHECK(qkv.is_contiguous());
TORCH_CHECK(cu_seqlens.is_contiguous());
TORCH_CHECK(blockmask.is_contiguous());
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
TORCH_CHECK(blockmask.dim() == 2);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64);
auto dqkv = torch::zeros_like(qkv);
auto opts = qkv.options();
auto softmax_d = torch::empty({batch_size, num_heads, seq_len}, opts.dtype(at::kFloat));
at::Tensor dq_tmp;
if (loop) {
// dq_tmp = torch::zeros({total, num_heads, head_size}, opts.dtype(at::kFloat));
dq_tmp = torch::empty({total, num_heads, head_size}, opts.dtype(at::kFloat));
}
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
out.data_ptr(),
loop ? dq_tmp.data_ptr() : nullptr,
dout.data_ptr(),
softmax.data_ptr(), // softmax gets overwritten by dP!
softmax_lse.data_ptr(),
softmax_d.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
params.blockmask = static_cast<int *>(blockmask.data_ptr());
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// We're gonna reset the rng state in Python after this kernel, so the counter offset
// here doesn't matter at all. We just choose an arbitrary number;
int64_t counter_offset = 4;
if( is_dropout ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
Data_type acc_type = DATA_TYPE_FP32;
params.dqkv_ptr = dqkv.data_ptr();
launch(params, stream);
return { dqkv, softmax, softmax_d };
// std::vector<at::Tensor> result = {dqkv, softmax, softmax_d};
// if (loop) {
// result.push_back(dq_tmp);
// }
// return result;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("fwd_block", &mha_fwd_block, "Forward pass (blocksparse)");
m.def("bwd_block", &mha_bwd_block, "Backward pass (blocksparse)");
}
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages
import subprocess
import sys
import warnings
import os
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0]
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does "
"not match the version used to compile Pytorch binaries. "
"Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+ "In some cases, a minor-version mismatch will not cause later errors: "
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
"You can try commenting out this check (at your own risk)."
)
def raise_if_cuda_home_none(global_option: str) -> None:
if CUDA_HOME is not None:
return
raise RuntimeError(
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
"only images whose names contain 'devel' will provide nvcc."
)
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
if not torch.cuda.is_available():
# https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
# which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
print(
"\nWarning: Torch did not find available GPUs on this system.\n",
"If your intention is to cross-compile, this is not an error.\n"
"By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
"Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
"and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
"If you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
)
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
if int(bare_metal_minor) > 0:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
cmdclass = {}
ext_modules = []
# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
raise_if_cuda_home_none("--streamattn")
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) < 11:
raise RuntimeError("--streamattn only supported on SM80+")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
ext_modules.append(
CUDAExtension(
name="stream_attn_cuda",
sources=[
"fmha_api.cpp",
"src/fmha_fprop_fp16_kernel.sm80.cu",
"src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
"src/fmha_block_fprop_fp16_kernel.sm80.cu",
"src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
],
extra_compile_args={
"cxx": ["-O3"] + generator_flag,
"nvcc": append_nvcc_threads(
[
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v",
"-lineinfo"
]
+ generator_flag
+ cc_flag
),
},
include_dirs=[
this_dir,
os.path.join(this_dir, "src"),
],
)
)
setup(
name="stream_attn_cuda",
version="0.1",
description="Streaming attention",
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension} if ext_modules else {},
)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <fmha_utils.h>
constexpr int TOTAL_DIM = 0;
constexpr int THREE_DIM = 1;
constexpr int H_DIM = 2;
constexpr int D_DIM = 3;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
// The QKV matrices.
void * __restrict__ qkv_ptr;
// The stride between rows of the Q, K and V matrices.
// size_t qkv_stride_in_elts;
// size_t qkv_stride_in_bytes;
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
uint32_t qkv_stride_in_elts;
uint32_t qkv_stride_in_bytes;
// The number of heads.
int h;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Fused_multihead_attention_fprop_params : public Qkv_params {
// The dQKV matrices.
void * __restrict__ dqkv_ptr;
// Temporary for dKV.
void * __restrict__ dkv_ptr;
// The O matrix (output).
void * __restrict__ o_ptr;
// The stride between rows of O.
// size_t o_stride_in_elts;
// size_t o_stride_in_bytes;
uint32_t o_stride_in_elts;
uint32_t o_stride_in_bytes;
// The pointer to the O_tmp matrix, which holds O intermediate value during
// the loop;
void *__restrict__ o_tmp_ptr;
// The dO matrix .
void * __restrict__ do_ptr;
// The pointer to the S matrix, overwritten by the dP matrix (bwd).
void * __restrict__ s_ptr;
// The stride between rows of the S matrix.
// int64_t s_stride_in_bytes;
uint32_t s_stride_in_bytes;
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
// The pointer to the softmax d sum.
void * __restrict__ dsoftmax_sum;
// The dimensions.
int b, s, d;
// The scaling factors for the kernel.
float scale_bmm1f;
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens;
int *__restrict__ blockmask;
// The dropout probability (probability of keeping an activation).
float p_dropout;
uint32_t p_dropout_in_uint;
uint16_t p_dropout_in_uint16_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
// Scale factor of 1 / (1 - p_dropout), in half2.
uint32_t scale_dropout;
// Random state.
at::PhiloxCudaState philox_args;
bool is_causal;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_params>
struct Launch_params{
Launch_params(cudaDeviceProp * props_,
cudaStream_t stream_,
bool is_dropout_,
bool return_softmax_)
: elts_per_thread(0)
, props(props_)
, stream(stream_)
, is_dropout(is_dropout_)
, return_softmax(return_softmax_) {
}
size_t elts_per_thread;
cudaDeviceProp * props;
cudaStream_t stream;
bool is_dropout;
bool return_softmax;
Kernel_params params;
int num_full_heads;
int num_main_groups;
int heads_last_wave;
int main_steps;
int rest_steps;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_block_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_block_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
\ No newline at end of file
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha/utils.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
struct Fragment_base_ {
// The data type.
using Data_type = Data_type_;
// default input type
using Input_type_ = Data_type_;
// Does it store the array of elements.
static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8;
// The number of elements.
static constexpr int NUM_ELTS = NUM_ELTS_;
// The size of element in bits.
static constexpr int BITS_PER_ELT = BITS_PER_ELT_;
// The size of byte of a single register.
static constexpr int BYTES_PER_REG = 4;
// The size in bits.
static constexpr int BITS_PER_REG = BYTES_PER_REG * 8;
// The number of registers needed to store the fragment.
static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG);
// The size in bytes (as returned by sizeof(Fragment_base<>).
static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG;
// The alignment.
static constexpr int ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The type of the elements.
typename Data_type_,
// The number of elements.
int NUM_ELTS_,
// The alignment if you want to force a value -- use 0 otherwise.
int ALIGNMENT_ = 0,
// The base class.
typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
// The size of a load/store.
static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t);
// Clear the fragment. Using PTX in that code seems to produce better SASS...
inline __device__ void clear() {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
}
}
// Immutable access to a register.
inline __device__ const uint32_t& reg(int ii) const {
return this->regs_[ii];
}
// Mutable access to a register.
inline __device__ uint32_t& reg(int ii) {
return this->regs_[ii];
}
uint32_t regs_[Base_::NUM_REGS];
// Immutable access to the elements.
inline __device__ const Data_type_& elt(int ii) const {
return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
}
// Mutable access to the elements.
inline __device__ Data_type_& elt(int ii) {
return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
}
// Immutable access to the elements with a cast.
template< typename Cast_type >
inline __device__ const Cast_type& elt_as(int ii) const {
return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
}
// Mutable access to the elements.
template< typename Cast_type >
inline __device__ Cast_type& elt_as(int ii) {
return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
}
// Add another fragment.
inline __device__ void add(const Fragment &other) {
// TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS?
// Also are we doing int addition or __half2 addition?
#pragma unroll
for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
this->elt(ii) += other.elt(ii);
}
}
// Multiply by another fragment.
inline __device__ void hmul(const Fragment &other) {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
}
}
inline __device__ void hrelu_() {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
this->reg(ii) = fmha::hrelu2(this->reg(ii));
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Layout >
struct Fragment_a : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Layout >
struct Fragment_b : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Fragment_accumulator : public Fragment<float, 8> {
// The base class.
using Base = Fragment<float, 8>;
// Add two fragments.
template< typename Other_fragment_ >
inline __device__ void add(const Other_fragment_ &other) {
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
this->elt(ii) = this->elt(ii) + other.elt(ii);
}
}
inline __device__ void mul_(const float other) {
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
this->elt(ii) *= other;
}
}
// Do the HMMA.
template< typename Layout_a, typename Layout_b >
inline __device__ void mma(const Fragment_a<Layout_a> &a,
const Fragment_b<Layout_b> &b) {
asm volatile( \
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
" {%0, %1, %2, %3}, \n" \
" {%4, %5, %6, %7}, \n" \
" {%8, %9}, \n" \
" {%0, %1, %2, %3}; \n" \
: "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3))
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
, "r"(b.reg(0)), "r"(b.reg(1)));
asm volatile( \
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
" {%0, %1, %2, %3}, \n" \
" {%4, %5, %6, %7}, \n" \
" {%8, %9}, \n" \
" {%0, %1, %2, %3}; \n" \
: "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7))
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
, "r"(b.reg(2)), "r"(b.reg(3)));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Fragment, int M, int N >
inline __device__ void clear(Fragment (&frag)[M][N]) {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < N; ++ni ) {
frag[mi][ni].clear();
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Accumulator_type, int WARPS_K >
struct Clear_accumulator {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int WARPS_K >
struct Clear_accumulator<float, WARPS_K> {
template< typename Acc, int M, int N >
static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
fmha::clear(acc);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < N; ++ni ) {
acc[mi][ni].mma(a[mi], b[ni]);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The number of rows in the CTA tile.
int M_,
// The number of cols in the CTA tile.
int N_,
// The number of elements in the the K dimension of the GEMM loop.
int K_,
// The number of rows of warps.
int WARPS_M_,
// The number of cols of warps.
int WARPS_N_,
// The number of warps in the K dimension of the GEMM loop.
int WARPS_K_>
struct Cta_tile_ {
static constexpr int M = M_, N = N_, K = K_;
// The number of warps.
static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
// The number of warps per CTA.
static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
// The number of threads per warp.
static constexpr int THREADS_PER_WARP = 32;
// The number of threads per CTA.
static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Hmma_tile {
// The number of elements computed with a single warp-MMA.
static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;
// The number of elements computed with a single CTA-MMA.
static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;
// The number of MMAs needed to compute the GEMM.
static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);
// // The number of elements computed per warp.
// static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA,
// N_PER_WARP = MMAS_N * N_PER_MMA,
// K_PER_WARP = MMAS_K * K_PER_MMA;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;
constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile_>
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
Cta_tile_::N,
Next_power_of_two<Cta_tile_::K>::VALUE,
Cta_tile_::WARPS_M,
Cta_tile_::WARPS_N,
Cta_tile_::WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile_,
// The number of bits per element.
int BITS_PER_ELEMENT,
// The number of rows of Q, K or V loaded by this tile.
int ROWS_,
// The number of columns.
int COLS,
// The number of matrics.
int NUM_MATS = 3
>
struct Gmem_tile_qkv {
using Cta_tile = Cta_tile_;
// The size of each LDG.
static constexpr int BYTES_PER_LDG = 16;
// The size of a row in bytes.
static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8;
// The number of threads to load a "row" of the matrix.
static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG;
static constexpr int ROWS = ROWS_;
// The number of "rows" loaded per LDG.
static constexpr int ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
// The number of LDGs needed to load a chunk of the Q matrix.
static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG);
// Ctor.
template< typename Params, typename BInfo >
inline __device__ Gmem_tile_qkv(const Params &params, const int qkv_offset, const BInfo &binfo, const int tidx)
: params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
, actual_seqlen(binfo.actual_seqlen)
, qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr))
, tidx_(tidx) {
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % THREADS_PER_ROW;
// Store the row as we need it to disable the loads.
// TD [2022-04-16]: To minimize registers, we'll recompute row_ instead of storing it
// row_ = row;
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
// int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
uint32_t row_offset = (uint32_t)row * params.qkv_stride_in_bytes;
// Add the block index.
// row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
row_offset += (uint32_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
// Assemble the final pointer.
qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
}
// Store data to shared memory.
template< typename Smem_tile >
inline __device__ void commit(Smem_tile &smem_tile) {
smem_tile.store(fetch_);
}
inline __device__ void load() {
int row_ = tidx_ / THREADS_PER_ROW;
const void *ptrs[LDGS];
uint32_t preds[LDGS];
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
// ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
ptrs[ii] = qkv_ptr_ + (uint32_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
fetch_[ii] = make_uint4(0, 0, 0, 0);
}
// not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
fct.load(ii, preds[ii]);
}
}
// Store data to memory.
inline __device__ void store(const uint4 (&data)[LDGS]) {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
// char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
char *ptr = qkv_ptr_ + (uint32_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
fmha::stg(ptr, data[ii]);
}
}
}
// Move the pointer to the next location.
inline __device__ void move() {
// qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;
qkv_ptr_ += (uint32_t)ROWS * params_qkv_stride_in_bytes_;
actual_seqlen -= ROWS;
}
inline __device__ void move(int steps) {
// qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;
qkv_ptr_ += (uint32_t)ROWS * params_qkv_stride_in_bytes_ * steps;
actual_seqlen -= ROWS * steps;
}
// The stride between rows for the QKV matrice.
// int64_t params_qkv_stride_in_bytes_;
uint32_t params_qkv_stride_in_bytes_;
// The pointer.
char *qkv_ptr_;
// The fetch registers.
uint4 fetch_[LDGS];
// Keep track of the row the thread is processing as we move the tile.
// int row_;
const int tidx_;
// The length of the sequence loaded by that memory tile.
int actual_seqlen;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Cta_tile,
int BYTES_PER_ELEMENT = 2
>
struct Gmem_tile_o {
static_assert(BYTES_PER_ELEMENT == 2 || BYTES_PER_ELEMENT == 4);
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The size of each element.
// static constexpr int BYTES_PER_ELEMENT = 2;
// The size of each STG.
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 4;
static constexpr int COLS = Cta_tile::N;
// The size of a row in bytes.
static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT;
// The number of threads to store a "row" of the matrix.
static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG;
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
static constexpr int ROWS = Cta_tile::M;
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA;
// The number of outter loop for the stores.
static constexpr int LOOPS = ROWS / ROWS_PER_LOOP;
// The number of "rows" stored per STG.
static constexpr int ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
// Do we have to guard against partial writes/reads.
static constexpr bool HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0;
// The number of STGs needed to store a chunk of the Q matrix.
static constexpr int STGS_PER_LOOP = DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_STG);
// The number of STGs needed to store a chunk of the Q matrix in total.
static constexpr int STGS = STGS_PER_LOOP * LOOPS;
// Ctor.
template<typename BInfo>
// inline __device__ Gmem_tile_o(void *ptr, const size_t stride_in_elts, const BInfo &binfo, const int tidx)
inline __device__ Gmem_tile_o(void *ptr, const uint32_t stride_in_elts, const BInfo &binfo, const int tidx)
: stride_in_bytes_(stride_in_elts * BYTES_PER_ELEMENT)
, actual_seqlen_(binfo.actual_seqlen)
, actual_seqlen(binfo.actual_seqlen)
, ptr_(reinterpret_cast<char *>(ptr))
, tidx_(tidx) {
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % THREADS_PER_ROW;
// Store the row as we need it to disable loads.
// row_ = row;
// The row offset in the batched GEMM.
// int64_t row_offset = (int64_t)row * stride_in_bytes_ + binfo.bidx * BYTES_PER_ROW;
uint32_t row_offset = (uint32_t)row * stride_in_bytes_ + binfo.bidx * BYTES_PER_ROW;
// Assemble the final pointer.
ptr_ += row_offset + col * BYTES_PER_STG;
// Is that thread active on the last STG?
if( HAS_INCOMPLETE_STG ) {
is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
}
}
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_o(const Params &params, const BInfo &binfo, const int tidx)
: Gmem_tile_o(params.o_ptr, params.o_stride_in_elts, binfo, tidx) {}
// Store data to global memory.
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
// if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) {
// break;
if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen ) {
break;
}
if (BYTES_PER_ELEMENT == 4) {
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_, src[ii]);
}
} else if (BYTES_PER_ELEMENT == 2) {
float x = reinterpret_cast<const float &>(src[ii].x);
float y = reinterpret_cast<const float &>(src[ii].y);
float z = reinterpret_cast<const float &>(src[ii].z);
float w = reinterpret_cast<const float &>(src[ii].w);
uint2 out = float4_to_half4(x, y, z, w);
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_, out);
}
}
}
}
// Store data to global memory.
inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
static_assert(BYTES_PER_ELEMENT == 4);
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen ) {
break;
}
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_);
}
}
}
// Move the pointer to the next location.
inline __device__ void move() {
// row_ += ROWS;
// ptr_ += (int64_t)ROWS * stride_in_bytes_;
ptr_ += (uint32_t)ROWS * stride_in_bytes_;
actual_seqlen -= ROWS;
}
inline __device__ void move(const int steps) {
// row_ += ROWS * steps;
// ptr_ += (int64_t)ROWS * stride_in_bytes_ * steps;
ptr_ += (uint32_t)ROWS * stride_in_bytes_ * steps;
actual_seqlen -= ROWS * steps;
}
// The stride between rows for the QKV matrice.
// int64_t stride_in_bytes_;
uint32_t stride_in_bytes_;
// The pointer.
char *ptr_;
// Is the thread active for the last STG?
int is_active_for_last_stg_;
// Keep track of the row to disable loads.
// int row_;
// The length of the sequence loaded by that memory tile.
const int actual_seqlen_;
int actual_seqlen;
const int tidx_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, int BYTES_PER_ELEMENT >
struct Gmem_tile_mma_sd {
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// Each STG stores 8 elements.
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8;
// The number of MMAs in the M dimension.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
// The number of MMAs in the N dimension.
static constexpr int MMAS_N = Mma_tile::MMAS_N;
// The number of rows computed per MMA per thread block.
static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
// The number of cols computed per MMA per thread block.
static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA;
// The number of threads per block.
static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA;
// The size of each row in bytes. I.e. how many bytes are stored per STG.
static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG;
// The distance between elements stored per loop (in bytes).
static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW;
// The type of elements stored per STG.
using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;
// Ctor.
template<typename Params>
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx)
: ptr_(static_cast<char *>(ptr)) {
// The block index.
// size_t bidx = bidb * params.h + bidh;
uint32_t bidx = bidb * params.h + bidh;
// The distance between two blocks (in bytes).
// const size_t block_stride_bytes = params.s * params.s * BYTES_PER_ELEMENT;
const uint32_t block_stride_bytes = params.s * params.s * BYTES_PER_ELEMENT;
// Set store location for each thread at the beginning of the loop
ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG;
}
// Store to global memory.
inline __device__ void store(const Type &data, const int mi, const int ni) {
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::stg(ptr_ + offset, data);
}
// Load from global memory.
inline __device__ void load(Type &data, const int mi, const int ni) {
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::ldg(data, ptr_ + offset);
}
// Move to the next tile.
inline __device__ void move() {
ptr_ += LOOP_STRIDE_BYTES;
}
inline __device__ void move(const int steps) {
ptr_ += LOOP_STRIDE_BYTES * steps;
}
// The pointer in global memory.
char *ptr_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
struct Gmem_tile_mma_s : public Base {
// The number of mmas in the vertical dimension.
static constexpr int M = Base::MMAS_M;
// The number of mmas in the horizontal dimension.
static constexpr int N = Base::MMAS_N;
// The type of the vectors stored by each STG.
using Type = typename Base::Type;
// Ctor.
template< typename Params, typename Block_info >
inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info& binfo, const int tidx)
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
}
// Store to global memory.
template<typename Mask>
inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
float tmp00 = softmax[2 * mi + 0][4 * ni + 0];
float tmp01 = softmax[2 * mi + 0][4 * ni + 1];
float tmp02 = softmax[2 * mi + 0][4 * ni + 2];
float tmp03 = softmax[2 * mi + 0][4 * ni + 3];
float tmp10 = softmax[2 * mi + 1][4 * ni + 0];
float tmp11 = softmax[2 * mi + 1][4 * ni + 1];
float tmp12 = softmax[2 * mi + 1][4 * ni + 2];
float tmp13 = softmax[2 * mi + 1][4 * ni + 3];
uint4 dst;
dst.x = fmha::float2_to_half2(tmp00, tmp01);
dst.y = fmha::float2_to_half2(tmp02, tmp03);
dst.z = fmha::float2_to_half2(tmp10, tmp11);
dst.w = fmha::float2_to_half2(tmp12, tmp13);
if( mask.is_valid(mi, ni, 0, 0) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 dst;
dst.x = frag[ni][mi].reg(0);
dst.y = frag[ni][mi].reg(2);
dst.z = frag[ni][mi].reg(1);
dst.w = frag[ni][mi].reg(3);
if( mask.any_valid(mi, ni) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Load from global memory.
template<typename Mask>
inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
regs[mi][ni] = make_uint4(0, 0, 0, 0);
if( mask.any_valid(mi, ni) ) {
Base::load(regs[mi][ni], mi, ni);
}
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The base class.
typename Base = fmha::Gmem_tile_qkv<Cta_tile, fmha::BITS_PER_ELEMENT_A, Cta_tile::M, Cta_tile::K>
>
struct Gmem_tile_dout : public Base {
// Ctor.
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_dout(void *ptr, const Params &params, const BInfo &binfo, int tidx)
: Base(params, 0, binfo, tidx) {
// this->qkv_ptr_ = reinterpret_cast<char *>(params.do_ptr);
this->qkv_ptr_ = static_cast<char *>(ptr);
this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes; // needed for move
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / Base::THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % Base::THREADS_PER_ROW;
// The row offset in the batched GEMM. For each seq element, we store O in that order.
// int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
// int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
uint32_t row_offset = (uint32_t)row * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
// Assemble the final pointer.
this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, typename Base = fmha::Gmem_tile_o<Cta_tile> >
struct Gmem_tile_dq : public Base {
// Ctor.
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_dq(const Params &params, const int qkv_offset, const BInfo &binfo, int tidx)
: Base(params.dqkv_ptr, params.qkv_stride_in_elts, binfo, tidx) {
this->ptr_ = reinterpret_cast<char *>(params.dqkv_ptr);
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / Base::THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % Base::THREADS_PER_ROW;
// The row offset in the batched GEMM. For each seq element, we store O in that order.
// int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes +
// ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
// int64_t row_offset = (int64_t)row * this->stride_in_bytes_ +
// ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
uint32_t row_offset = (uint32_t)row * this->stride_in_bytes_ +
((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
// Assemble the final pointer.
this->ptr_ += row_offset + col * Base::BYTES_PER_STG;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile
>
struct Gmem_summary_stats {
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
// The size of each element.
static constexpr int BYTES_PER_ELEMENT = 4;
static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT;
static constexpr int ROWS = Cta_tile::M;
// Ctor.
template<typename Params>
inline __device__ Gmem_summary_stats(void *ptr, const Params &params, const int tidx)
: ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The block index.
// size_t bidx = bidb * params.h + bidh;
uint32_t bidx = bidb * params.h + bidh;
// Extract the position in the warp.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// The distance between two blocks (in bytes).
// size_t block_stride_bytes = params.s * BYTES_PER_ELEMENT;
uint32_t block_stride_bytes = params.s * BYTES_PER_ELEMENT;
// Set store location for each thread at the beginning of the loop
ptr_row_ = ptr_ + bidx * block_stride_bytes;
ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT;
}
// Store data to global memory.
inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) {
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
if ((warp == 0) && (lane % 4 == 0)) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]);
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]);
}
}
}
// Store data to global memory.
inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]);
}
}
// Load from global memory.
inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
}
}
// Load from global memory.
inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps=1) {
char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT;
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
}
}
// Store data to global memory.
template <int N>
inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) {
#pragma unroll
for (int ni = 0; ni < N; ++ni) {
fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT);
}
}
// Move the pointer to the next location.
inline __device__ void move() {
ptr_ += ROWS * BYTES_PER_ELEMENT;
ptr_row_ += ROWS * BYTES_PER_ELEMENT;
}
// Move the pointer to the next location.
inline __device__ void move(const int steps) {
ptr_ += ROWS * BYTES_PER_ELEMENT * steps;
ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps;
}
// The pointer.
char *ptr_;
char *ptr_row_;
const int tidx_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u>
struct FMHA_kernel_traits {
// The CTA description for the 1st GEMM.
using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;
// The CTA description for the 2nd GEMM.
using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;
// Do we use one buffer for K and V.
static constexpr bool SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u;
// Do we keep K in registers.
static constexpr bool K_IN_REGS = (FLAGS & 0x10u) == 0u;
// Do we keep V in registers.
static constexpr bool V_IN_REGS = (FLAGS & 0x100u) == 0u;
// The global memory tile to load Q.
using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
// The shared memory tile to swizzle Q.
// using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The global memory tile to load K.
using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;
// The shared memory tile to swizzle K.
using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;
// The global memory tile to load V.
using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;
// The shared memory tile to swizzle V.
using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;
// The global memory tile to store O.
using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;
// The shared memory tile for O.
using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;;
// The global memory tile to load/store S.
using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;
// The shared memory tile to transpose S.
using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;
using Gmem_tile_do = fmha::Gmem_tile_dout<Cta_tile_p>;
using Gmem_tile_dot = fmha::Gmem_tile_dout<Cta_tile_p, fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D> >;
// The global memory tile to store the softmax sum.
using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;
// The shared memory tile to store dp sum.
using Smem_dp_sum = fmha::Smem_tile_dp_sum<Gmem_tile_q, 2>;
// Make sure the number of threads match.
static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");
// The number of threads.
static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA;
// Make sure the number of threads matches both CTAs.
static_assert(THREADS == Cta_tile_o::THREADS_PER_CTA, "");
// The amount of shared memory needed to load Q and K.
static constexpr int BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE;
// The extra amount of shared memory needed to load V.
static constexpr int BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE;
// The amount of shared memory needed for Q, K and V..
static constexpr int BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V;
// The amount of shared memory needed to load Q and store O.
static constexpr int BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE;
// The amount of shared memory needed for Q, K, V and O.
static constexpr int BYTES_PER_SMEM = fmha::MaxConstexpr(BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO);
// Make sure we have enough shared memory.
static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {
template<typename Cta_tile, bool Is_causal=false>
struct Mask {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
template<typename BInfo>
__device__ Mask(const BInfo &blockInfo, int tidx, const int loop_step_idx_ = 0)
: actual_seqlen(blockInfo.actual_seqlen - loop_step_idx_ * Cta_tile::N)
, loop_step_idx(loop_step_idx_) {
const int warp = tidx / Cta_tile::THREADS_PER_WARP;
const int lane = tidx % Cta_tile::THREADS_PER_WARP;
static_assert(Cta_tile::WARPS_K == 1, "");
// find the warp in the Cta tile
const int warp_n = (warp / Cta_tile::WARPS_M);
const int warp_m = (warp % Cta_tile::WARPS_M);
// decompose warp into 8x4 tile
const int quad = lane / 4;
const int tid = (lane % 4) * 2;
row = warp_m * 16 + quad;
col = warp_n * 16 + tid;
}
inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
// ii and jj iterate over the 2x4 fragment
// const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
const int current_row = row_offset + ii * 8;
const bool col_valid = current_col < actual_seqlen;
// const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen;
//&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;
bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("current_col=%d, current_row=%d, actual_seqlen=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen, col_valid, all_valid);
// }
return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
// return row_valid && col_valid;
}
//BERT Mask: if upper left is invalid, none are valid
inline __device__ bool any_valid(const int mi, const int ni) const {
return is_valid(mi, ni, 0, 0) || is_valid(mi, ni, 1, 0);
}
inline __device__ void load(const int it) {
row_offset = it * Cta_tile::M + row;
}
int row_offset;
int row;
int col;
const int loop_step_idx;
const int actual_seqlen;
};
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "utils.h"
#include <fmha/utils.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The description of the tile computed by this CTA.
typename Cta_tile,
// The number of rows in the 2D shared memory buffer.
int M_,
// The number of cols.
int N_,
// The size in bits of each element.
int BITS_PER_ELEMENT_,
// The number of bytes per STS.
int BYTES_PER_STS_ = 16,
// The number of buffers. (Used in multistage and double buffer cases.)
int BUFFERS_PER_TILE_ = 1,
// Do we enable the fast path for LDS.128 and friends.
int ENABLE_LDS_FAST_PATH_ = 0,
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
int ROWS_PER_XOR_PATTERN_ = 8,
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
int COLS_PER_XOR_PATTERN_ = 1,
// Use or not predicates
bool USE_PREDICATES_ = true
>
struct Smem_tile_without_skews {
// The size in bits of each element.
enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };
// The size in bytes of a single STS.
enum { BYTES_PER_STS = BYTES_PER_STS_ };
// The number of elements per STS.
enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };
// To support arbitrary N, we pad some values to a power-of-2.
enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE };
// The number of bytes per row without packing of rows.
enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
// The number of bytes per row -- we want at least 128B per row.
enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
// The number of rows in shared memory (two rows may be packed into a single one).
enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };
// The number of threads per row.
enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS };
// The number of threads per row.
enum { THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE };
// The number of STS per row.
enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS };
// It must be at least one.
static_assert(STS_PER_ROW >= 1, "");
// The number of rows written with a single STS.
enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
// Make sure we write to at least one row per STS. Thanks Dr. Obvious ;)
static_assert(ROWS_PER_STS >= 1, "");
// The number of STS needed to store all rows.
enum { STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE };
// The number of STS in total.
enum { STS = STS_PER_COL * STS_PER_ROW };
// The size of one buffer in bytes in shared memory.
enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };
// The number of buffers.
enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };
// The size in bytes of total buffers.
enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE };
// The boundary for smem_read_offset and smem_write_offset increment.
enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER };
// Do we enable the LDS.128 fast path?
enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ };
static_assert(ENABLE_LDS_FAST_PATH == 0);
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ };
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS };
// Use or not predicates
enum { USE_PREDICATES = USE_PREDICATES_ };
// The type of elements that are stored in shared memory by each thread.
using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;
// Ctor.
inline __device__ Smem_tile_without_skews(void *smem, int tidx)
: smem_(__nvvm_get_smem_pointer(smem)) {
// The row written by a thread. See doc/mma_smem_layout.xlsx.
int smem_write_row = tidx / THREADS_PER_ROW;
// The XOR pattern.
int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;
// Compute the column and apply the XOR pattern.
int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;
// The offset.
this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
// TODO: Why not merge it with the read offset?
// this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
// this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
}
// Compute the store pointers.
template< int N >
inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
// Decompose the STS into row/col.
int row = ii / STS_PER_ROW;
int col = ii % STS_PER_ROW;
// Assemble the offset.
int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW;
// Take the column into account.
if( STS_PER_ROW > 1 ) {
offset += col*THREADS_PER_ROW*BYTES_PER_STS;
}
// Apply the XOR pattern if needed.
if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {
const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
}
// Assemble the final pointer :)
// ptrs[ii] = smem_ + offset + smem_write_buffer_;
// smem_write_buffer_ is already merged with smem_write_offset_
ptrs[ii] = smem_ + offset;
}
}
inline __device__ void debug_reset() {
for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
for( int row = 0; row < ROWS; ++row ) {
for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
if( threadIdx.x == 0 ) {
uint32_t val = 0x0;
sts(val, smem_ + row*BYTES_PER_ROW + col + buffer);
}
}
}
}
}
// Print the content of the tile (only for debug ;)).
inline __device__ void debug_print() const {
for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
for( int row = 0; row < ROWS; ++row ) {
for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
if( threadIdx.x == 0 ) {
uint32_t val;
lds(val, smem_ + row*BYTES_PER_ROW + col + buffer);
printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n",
blockIdx.x,
blockIdx.y,
blockIdx.z,
smem_,
buffer,
row,
col,
val);
}
}
}
}
}
// Move the read offset to next buffer.
inline __device__ void move_to_next_read_buffer() {
// if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
// this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
// } else if( BUFFERS_PER_TILE > 1 ) {
// this->smem_read_buffer_ += BYTES_PER_BUFFER;
// }
if( BUFFERS_PER_TILE > 1 && smem_read_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
this->smem_read_offset_ -= BYTES_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_read_offset_ += BYTES_PER_BUFFER;
}
}
// Move the read offset to next buffer. TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer() {
this->move_to_next_read_buffer();
}
// Move the read offset to next N buffer (circular-buffer).
inline __device__ void move_to_next_read_buffer(int N) {
if( BUFFERS_PER_TILE > 1 ) {
// this->smem_read_buffer_ += N * BYTES_PER_BUFFER;
// this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
this->smem_read_offset_ += N * BYTES_PER_BUFFER;
this->smem_read_offset_ -= smem_read_offset_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
}
}
// Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer(int N) {
this->move_to_next_read_buffer(N);
}
// Move the write offset to next buffer.
inline __device__ void move_to_next_write_buffer() {
// if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
// this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
// } else if( BUFFERS_PER_TILE > 1 ) {
// this->smem_write_buffer_ += BYTES_PER_BUFFER;
// }
if( BUFFERS_PER_TILE > 1 && smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
this->smem_write_offset_ -= BYTES_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_write_offset_ += BYTES_PER_BUFFER;
}
}
// Move the write offset to next buffer. TODO: Remove that member function!
inline __device__ void move_next_write_buffer() {
this->move_to_next_write_buffer();
}
// Move the read offset.
inline __device__ void move_read_offset(int delta) {
this->smem_read_offset_ += delta;
}
// Move the write offset.
inline __device__ void move_write_offset(int delta) {
this->smem_write_offset_ += delta;
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
sts(smem_ptrs, data);
}
// Store to the tile in shared memory.
template< int N, int M >
inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) {
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
sts(smem_ptrs, data, preds);
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) {
this->store(data, preds);
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) {
uint32_t tmp[1] = { preds };
this->store(gmem_ptrs, tmp);
}
// The shared memory pointer.
const uint32_t smem_;
// The read offset. Reserve 4 offsets if needed.
int smem_read_offset_;
// The write offset.
int smem_write_offset_;
// The buffer base offset for read.
// int smem_read_buffer_;
// The buffer base offset for write.
// int smem_write_buffer_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
// Use or not predicates
bool USE_PREDICATES = true
>
struct Smem_tile_a {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K, int MMAS_K_WITH_PADDING >
struct Compute_reset_mask {
// The potential mask.
enum { HALF = MMAS_K_WITH_PADDING / 2 };
// The remainder.
enum { MOD = MMAS_K % HALF };
// The final value.
enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask<MOD, HALF>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K_WITH_PADDING >
struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> {
enum { VALUE = 0 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K >
struct Compute_reset_mask<MMAS_K, MMAS_K> {
enum { VALUE = MMAS_K - 1 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_a {
// The size in bits.
enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };
// The number of rows.
enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a<Cta_tile::K>::VALUE
>
struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
Cta_tile::M,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_A,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::M,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_A,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1>;
// The fragment.
using Fragment = Fragment_a<Row>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// Ctor.
inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) {
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(WARPS_M == 1);
static_assert(WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_K == 1);
static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);
// The row and column read by the thread.
int smem_read_row = (tidx & 0x0f);
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
smem_read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {
#pragma unroll
for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) {
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
// ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset);
// Store the value into the fragment.
a[mi].reg(0) = tmp.x;
a[mi].reg(1) = tmp.y;
a[mi].reg(2) = tmp.z;
a[mi].reg(3) = tmp.w;
}
// Move the offset to the next possition. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset() {
// The number of MMAs in the K dimension.
enum { MMAS_K = Mma_tile::MMAS_K };
// The number of MMAs in the K dimension when we include padding.
enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
// Assemble the mask.
enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_row_a<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
// Use or not predicates
bool USE_PREDICATES = true
>
struct Smem_tile_b {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_b {
// The size in bits.
enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };
// The number of rows.
enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b<Cta_tile::K>::VALUE
>
struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
Cta_tile::N,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::N,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1>;
// The fragment.
using Fragment = Fragment_b< Col>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// The number of STS per thread
enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
// The number of STS per thread must be at least 1.
enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };
// Ctor.
inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) {
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);
static_assert(WARPS_M == 1);
static_assert(WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_K == 1);
// The masks to select the warps.
const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
// The divisor for the warps.
const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
// The row and column read by the thread.
int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA +
(tidx & 0x07) +
(tidx & 0x10) / 2;
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
smem_read_col ^= (tidx & 0x08) / 8;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
// ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset);
// Store the value into the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
}
// Move the offset to the next possition. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset() {
// The number of MMAs in the K dimension.
enum { MMAS_K = Mma_tile::MMAS_K };
// The number of MMAs in the K dimension when we include padding.
enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
// Assemble the mask.
enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_b< Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE >
: public Smem_tile_col_b<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_col_b< Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b< N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b<Cta_tile::N>::VALUE,
// How many cols to use for the XOR pattern to avoid bank conflicts?
int COLS_PER_XOR_PATTERN_ = 1
>
struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
Cta_tile::K,
Cta_tile::N,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
COLS_PER_XOR_PATTERN_> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::K,
Cta_tile::N,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
COLS_PER_XOR_PATTERN_>;
// The fragment.
using Fragment = Fragment_b<Row>;
// Can we use LDSM? No if the data type is 32-bit large.
enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };
// The number of elements per LDS.
enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };
// The number of STS per thread
enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
// The number of STS per thread must be at least 1.
enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };
// Ctor.
inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) {
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(WARPS_K == 1);
static_assert(WARPS_M == 4 || WARPS_M == 8);
static_assert(WARPS_N == 1);
// The masks to select the warps.
const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
const int WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
static_assert(USE_LDSMT);
static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);
// The row/col read by the thread.
int smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 +
(tidx & 0x07) + (tidx & 0x08);
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
// Fill zeroes for group conv
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// The size of each element in bits.
const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( BYTES_PER_MMA_PER_CTA >= 128 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
} else if( BYTES_PER_MMA_PER_CTA == 64 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
Mma_tile::MMAS_N % 2 == 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
// The size of each element in bits.
const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Prepare the offset.
int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW_BEFORE_PACKING;
if ( BYTES_PER_MMA_PER_CTA == 32 ) {
offset += this->smem_read_offset_;
} else if ( BYTES_PER_MMA_PER_CTA == 64 ) {
offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2;
} else {
offset += this->smem_read_offset_ + (ni ) * BYTES_PER_MMA_PER_CTA;
}
// Load the data using LDSM.MT88.2.
// uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;
uint32_t ptr = this->smem_ + offset;
uint4 tmp;
if( USE_LDSMT ) {
ldsmt(tmp, ptr);
} else {
lds(tmp.x, (ptr ) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING);
lds(tmp.y, (ptr ) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING);
lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING);
lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING);
}
// Store those values in the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if( BYTES_PER_MMA_PER_CTA >= 128 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
} else if( BYTES_PER_MMA_PER_CTA == 64 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
Mma_tile::MMAS_N % 2 == 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_row_b<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, Rows_per_xor_pattern_col_b<Cta_tile::N>::VALUE, 1> {
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, Rows_per_xor_pattern_col_b<Cta_tile::N>::VALUE, 1>;
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The fragment.
using Fragment = Fragment_b< fmha::Col>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// Ctor.
inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) {
// The row/col read by the thread.
int read_row, read_col;
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
read_col = ((read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW_BEFORE_PACKING + read_col * BYTES_PER_LDS;
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Jump by 16 * #warps row.
int row = ki * 16 * Cta_tile::WARPS_K;
// Load the data using LDSM.MT88.2.
uint4 tmp;
fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW_BEFORE_PACKING);
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if( Mma_tile::MMAS_N == 1 ) {
// noop
} else if( Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
} else if( Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if (Mma_tile::MMAS_N == 8) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
} else {
assert(false); // Not implemented!
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Smem_tile_o {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The accumulators.
using Accumulator = fmha::Fragment_accumulator;
// The accumulators.
using Data_type = typename Accumulator::Data_type;
// The size of each element.
static constexpr int BYTES_PER_ELEMENT = sizeof(Data_type);
// The size of each STS.
static constexpr int BYTES_PER_STS = 8;
// The size of each row in shared memory.
static constexpr int BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT;
// The size of each LDS.
static constexpr int BYTES_PER_LDS = 16;
static constexpr int THREADS_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT / BYTES_PER_LDS;
// The number of rows.
static constexpr int ROWS = Cta_tile::M;
// The number of "rows" to process per loop iteration (in the "epilogue").
static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA;
// The number of outer loops.
static constexpr int LOOPS = ROWS / ROWS_PER_LOOP;
// Make sure it matches our expectations.
static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");
// The number of rows loaded per LDS.
static constexpr int ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
// Do we have to guard against partial writes/reads.
static constexpr bool HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0;
// The total number of LDS per loop.
static constexpr int LDS_PER_LOOP = fmha::DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_LDS);
// The amount of shared memory.
static constexpr int BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW;
// The write pointer.
uint32_t smem_write_, smem_read_;
// Is the thread active for the last LDS of the series?
int is_active_for_last_lds_;
// static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K);
static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");
// Ctor.
inline __device__ Smem_tile_o(void *smem, int tidx) {
// Get a 32-bit value for the shared memory address.
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
static_assert(Cta_tile::N == 16 || Cta_tile::N == 32 || Cta_tile::N == 64 || Cta_tile::N == 128);
int write_row = (tidx & 0x1c) / 4;
const int lane = tidx % 32;
const int warp = tidx / 32;
constexpr int ELEMENTS_PER_STS = BYTES_PER_STS / BYTES_PER_ELEMENT;
constexpr int STS_PER_WARP = 16 * Mma_tile::MMAS_N / ELEMENTS_PER_STS;
int write_col = warp * STS_PER_WARP + lane % STS_PER_WARP;
// Assemble the write pointer.
smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
// The element read by each thread.
int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
// Take the XOR pattern into account for the column.
// read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8)));
read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8))));
// Assemble the read pointer.
this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
// Is that thread active on the last LDS?
if( HAS_INCOMPLETE_LDS ) {
this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M;
}
}
// Load the output fragments.
template <bool zero_init=true>
inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
#pragma unroll
for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {
// Load the elements before the reduction (split-K).
uint4 tmp[Cta_tile::WARPS_K];
#pragma unroll
for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
fmha::lds(tmp[jj], this->smem_read_ + imm);
}
}
// Perform the reduction.
out[ii] = zero_init ? tmp[0] : fmha::fadd4(out[ii], tmp[0]);
#pragma unroll
for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
out[ii] = fmha::fadd4(out[ii], tmp[jj]);
}
}
}
// Store the accumulators.
template <int M, int N>
inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// The number of MMAs that are stored per loop iteration.
static constexpr int MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS;
// Store 1st column of the different MMAs.
#pragma unroll
for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
// Precompute the immediates to jump between rows.
int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;
uint2 tmp0, tmp1;
tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0);
tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1);
tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2);
tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3);
// Store.
fmha::sts(this->smem_write_ + row_0, tmp0);
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// Swizzle the write pointer using a XOR of 16B.
this->smem_write_ ^= 32;
// Store 2nd column of the different MMAs.
#pragma unroll
for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
// Precompute the immediates to jump between rows.
int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;
uint2 tmp0, tmp1;
tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4);
tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5);
tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6);
tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7);
// Store.
fmha::sts(this->smem_write_ + row_0, tmp0);
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_mma {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
using Fragment = fmha::Fragment_a<fmha::Col>;
enum { COLS = Cta_tile::N };
enum { BYTES_PER_ELT = 2 };
enum { BYTES_PER_STS = 4 };
enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO
enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
enum { WARPS_K = Cta_tile::WARPS_K };
static_assert(WARPS_K == 1);
inline __device__ Smem_tile_mma(char *smem, int tidx) {
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
int write_col, write_row;
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
write_row = (tidx & 0x1c) / 4;
write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
} else {
write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
write_col = (tidx & 0x03);
}
// TODO [TD] Only works for, D=16, D=32 or D=64
write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : 0x07))) * 4;
// write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
static_assert(COLS == Cta_tile::N);
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
// fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
// fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
// offset ^= 4 * BYTES_PER_STS;
// fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
// fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
// size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * BYTES_PER_STS;
fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
}
}
}
template<typename Fragment, int M, int N>
inline __device__ void store(const Fragment (&frag)[N][M]) {
static_assert(COLS == Cta_tile::N);
uint4 regs[M][N];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// Need to transpose ref(1) and reg(2) here since when we load it we transpose again.
regs[mi][ni] = make_uint4(frag[ni][mi].reg(0), frag[ni][mi].reg(2),
frag[ni][mi].reg(1), frag[ni][mi].reg(3));
}
}
this->store(regs);
}
// uint32_t smem_;
// uint32_t write_offset_;
uint32_t smem_write_;
};
template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_transposed : public Base {
enum { BYTES_PER_LDS = 16 };
enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
using Fragment = typename Base::Fragment;
inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
int read_row, read_col;
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
// read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
template<int M, int N>
inline __device__ void load(Fragment (&frag)[M][N]) {
static_assert(Base::COLS == Cta_tile::N);
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
// fmha::ldsmt(dst, this->smem_ + offset);
// size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::ldsmt(dst, offset);
frag[mi][ni].reg(0) = dst.x;
frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major!
frag[mi][ni].reg(2) = dst.y;
frag[mi][ni].reg(3) = dst.w;
}
}
}
// uint32_t read_offset_;
uint32_t smem_read_;
};
template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_epilogue : public Base {
enum { BYTES_PER_LDS = 16 };
enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS };
static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW);
enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS };
static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
using Acc = fmha::Fragment_accumulator;
inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
const int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
// read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
inline __device__ void load(uint4 (&data)[NUM_LDS]) {
for( int ii = 0; ii < NUM_LDS; ii++ ) {
// size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
// fmha::lds(data[ii], this->smem_ + offset);
// size_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
uint32_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
fmha::lds(data[ii], offset);
}
}
template<int M, int N>
inline __device__ void store(const Acc (&acc)[M][N]){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// 1st row - 4 elements per row.
float tmp00 = acc[mi][ni].elt(0);
float tmp01 = acc[mi][ni].elt(1);
float tmp02 = acc[mi][ni].elt(4);
float tmp03 = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc[mi][ni].elt(2);
float tmp11 = acc[mi][ni].elt(3);
float tmp12 = acc[mi][ni].elt(6);
float tmp13 = acc[mi][ni].elt(7);
uint32_t x = fmha::float2_to_half2(tmp00, tmp01);
uint32_t y = fmha::float2_to_half2(tmp02, tmp03);
uint32_t z = fmha::float2_to_half2(tmp10, tmp11);
uint32_t w = fmha::float2_to_half2(tmp12, tmp13);
// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);
// fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);
// offset ^= 4 * Base::BYTES_PER_STS;
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
// fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
// size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
fmha::sts(offset + 0 * BYTES_PER_ROW, x);
fmha::sts(offset + 8 * BYTES_PER_ROW, z);
offset ^= 4 * Base::BYTES_PER_STS;
fmha::sts(offset + 0 * BYTES_PER_ROW, y);
fmha::sts(offset + 8 * BYTES_PER_ROW, w);
}
}
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * Base::BYTES_PER_STS;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
}
}
}
// uint32_t read_offset_;
uint32_t smem_read_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_transpose {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
using Fragment_write = fmha::Fragment_b<fmha::Col>;
using Fragment_read = fmha::Fragment_b<fmha::Col>;
enum { COLS = Cta_tile::N };
enum { BYTES_PER_ELT = 2 };
enum { BYTES_PER_STS = 4 };
enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO
enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };
enum { BYTES_PER_LDS = 16 };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
enum { WARPS_K = Cta_tile::WARPS_K };
static_assert(WARPS_K == 1);
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
inline __device__ Smem_tile_transpose(char *smem, int tidx) {
smem_ = __nvvm_get_smem_pointer(smem);
// uint32_t smem_ = __nvvm_get_smem_pointer(smem);
int write_col, write_row;
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
write_row = (tidx & 0x1c) / 4;
write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
} else {
write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
write_col = (tidx & 0x03);
}
write_col ^= (write_row & 0x07) * 4;
write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
// smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
int read_row, read_col;
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
read_col ^= (read_row & 0x07);
read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
// smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
template<int M, int N>
inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2));
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3));
}
}
template<int N>
inline __device__ void load(Fragment_read (&frag_r)[N]) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
frag_r[ni].reg(0) = dst.x;
frag_r[ni].reg(1) = dst.y; // Fragment B regs col major!
frag_r[ni].reg(2) = dst.z;
frag_r[ni].reg(3) = dst.w;
}
}
template<int M, int N>
inline __device__ void transpose(const Fragment_write (&frag_w)[M][N], Fragment_read (&frag_r)[M], int mi) {
static_assert(COLS == Cta_tile::N);
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2));
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3));
}
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
frag_r[ni].reg(0) = dst.x;
frag_r[ni].reg(1) = dst.y; // Fragment B regs col major!
frag_r[ni].reg(2) = dst.z;
frag_r[ni].reg(3) = dst.w;
}
}
uint32_t smem_;
uint32_t write_offset_;
uint32_t read_offset_;
// uint32_t smem_write_;
// uint32_t smem_read_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Gmem_tile,
// The number of buffers. (Used in multistage and double buffer cases.)
int BUFFERS_PER_TILE_ = 1
>
struct Smem_tile_dp_sum {
using Cta_tile = typename Gmem_tile::Cta_tile;
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The size of each element.
static constexpr int BYTES_PER_ELEMENT = 4;
static constexpr int ROWS = Gmem_tile::ROWS;
static constexpr int THREADS_PER_ROW = Gmem_tile::THREADS_PER_ROW;
static constexpr int MMAS_M = Mma_tile::MMAS_M;
static constexpr int ROWS_PER_LDG = Gmem_tile::ROWS_PER_LDG;
static constexpr int LDGS = Gmem_tile::LDGS;
static constexpr int ROWS_PER_MMA = Mma_tile::M_PER_MMA;
// The size of one buffer in bytes in shared memory.
static constexpr int BYTES_PER_BUFFER = ROWS * BYTES_PER_ELEMENT;
// The number of buffers.
static constexpr int BUFFERS_PER_TILE = BUFFERS_PER_TILE_;
// The size in bytes of total buffers.
static constexpr int BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE;
// The boundary for smem_read_offset and smem_write_offset increment.
static constexpr int ROWS_PER_TILE_INC_BOUNDARY = ROWS * BUFFERS_PER_TILE - ROWS;
inline __device__ Smem_tile_dp_sum(float *smem, const int tidx)
: smem_(smem), smem_read_buffer_(smem), smem_write_buffer_(smem), tidx_(tidx) {
}
// Move the read offset to next buffer.
inline __device__ void move_to_next_read_buffer() {
if( BUFFERS_PER_TILE > 1 && (smem_read_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) {
this->smem_read_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_read_buffer_ += ROWS;
}
}
// Move the write offset to next buffer.
inline __device__ void move_to_next_write_buffer() {
if( BUFFERS_PER_TILE > 1 && (smem_write_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) {
this->smem_write_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_write_buffer_ += ROWS;
}
}
inline __device__ void store(const float (&sum)[LDGS]) {
if (tidx_ % THREADS_PER_ROW == 0) {
int row = tidx_ / THREADS_PER_ROW;
#pragma unroll
for (int i = 0; i < LDGS; ++i) {
if (row + i * ROWS_PER_LDG < ROWS) {
smem_write_buffer_[row + i * ROWS_PER_LDG] = sum[i];
}
}
}
}
inline __device__ void store(const float sum, const int buffer_idx) {
float *smem_write = smem_ + buffer_idx * ROWS;
int row = tidx_ / THREADS_PER_ROW;
if ((row < ROWS) && (tidx_ % THREADS_PER_ROW == 0)) {
smem_write[row] = sum;
}
}
inline __device__ void store(const float (&sum)[LDGS], const int buffer_idx) {
float *smem_write = smem_ + buffer_idx * ROWS;
if (tidx_ % THREADS_PER_ROW == 0) {
int row = tidx_ / THREADS_PER_ROW;
#pragma unroll
for (int i = 0; i < LDGS; ++i) {
if (row + i * ROWS_PER_LDG < ROWS) {
smem_write[row + i * ROWS_PER_LDG] = sum[i];
}
}
}
}
inline __device__ void store_pair(const float (&sum)[MMAS_M * 2], const int buffer_idx) {
float *smem_write = smem_ + buffer_idx * ROWS;
// Extract the position in the warp.
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
int row = lane / 4;
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0];
smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1];
}
}
template<int N>
inline __device__ void load(float (&sum)[N], const int (&row)[N]) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
sum[ni] = smem_read_buffer_[row[ni]];
}
}
template<int N>
inline __device__ void load(float (&sum)[N], const int (&row)[N], const int buffer_idx) {
float *smem_read = smem_ + buffer_idx * ROWS;
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
sum[ni] = smem_read[row[ni]];
}
}
static inline __device__ float reduce_warp(float sum) {
fmha::SumOp<float> sum_op;
return fmha::Allreduce<THREADS_PER_ROW>::run(sum, sum_op);
}
const int tidx_;
float * const smem_;
float *smem_read_buffer_;
float *smem_write_buffer_;
};
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cmath>
#include <cuda_fp16.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Sum_ {
static constexpr bool IS_SUM = true;
static inline __device__ float apply(float x, float y) {
return x + y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Max_ {
static constexpr bool IS_SUM = false;
static inline __device__ float apply(float x, float y) {
return x > y ? x : y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float apply_exp_(float x, float max) {
return __expf(x - max);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __half2 apply_exp_(__half2 x, __half2 max) {
return h2exp(__hsub2(x, max));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float apply_exp2_(float x, float max) {
return exp2f(x - max);
// With fast-math, this produces the same PTX instruction as the assembly below
// float diff = x - max;
// float res;
// asm ("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff));
// return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __half2 apply_exp2_(__half2 x, __half2 max) {
return h2exp2(__hsub2(x, max));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int COLS, bool half> struct ReadType {};
template<> struct ReadType<4, false> { using T = float;};
template<> struct ReadType<8, false> { using T = float2;};
template<> struct ReadType<4, true> { using T = __half2;};
template<> struct ReadType<8, true> { using T = float2;};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile, typename Kernel_traits>
struct Smem_tile_reduce {
// Helper class to distribute MMA tiles reduced over rows per warp over quads.
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
static constexpr int MMAS_N = Mma_tile::MMAS_N;
static constexpr int WARPS_M = Cta_tile::WARPS_M;
static constexpr int WARPS_N = Cta_tile::WARPS_N;
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
static constexpr int COLS = WARPS_N;
static_assert(COLS == 4 || COLS == 8);
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
static constexpr int ELTS_PER_TILE = ROWS * COLS;
static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
// TD [2022-05-02]: No longer true if head_dim != 64
// static_assert(THREADS_PER_GROUP == 16); // DEBUG
static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
static_assert(LOOPS == 1);
using read_t = typename ReadType<COLS, /*half=*/false>::T;
using read_half_t = typename ReadType<COLS, /*half=*/true>::T;
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
int lane = tidx % 32;
int warp = tidx / 32;
int warp_m = warp % WARPS_M;
int warp_n = warp / WARPS_M;
qid_ = lane % 4;
int qp = lane / 4;
// Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
// This won't affect reading as we assume commutative reduction ops.
const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
smem_read_row_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qid_];
}
__device__ inline void store(float (&frag)[2 * MMAS_M]) {
if( qid_ == 0 ) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * WARPS_N;
smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
}
}
}
__device__ inline void store(__half2 (&frag)[MMAS_M]) {
__half2 *smem_write_half_ = reinterpret_cast<__half2 *>(smem_write_);
if( qid_ == 0 ) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * WARPS_N;
smem_write_half_[offset + 0 * 8 * WARPS_N] = frag[mi];
}
}
}
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
}
}
__device__ inline void load(read_half_t (&frag)[MMAS_M]) {
read_half_t *smem_read_half_ = reinterpret_cast<read_half_t *>(smem_read_);
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi] = smem_read_half_[offset + 0 * 8 * 4];
}
}
__device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4];
}
}
int qid_;
float *smem_write_;
read_t *smem_read_;
read_t *smem_read_row_;
};
template<typename Cta_tile, typename Kernel_traits>
struct Softmax_base {
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
static constexpr int MMAS_N = Mma_tile::MMAS_N;
// The number of groups of warp such that we have at most 4 warps writing consecutive elements.
static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4);
// The number of elements that we are going to store per row.
static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS;
// The number of rows.
static constexpr int ROWS = Cta_tile::M * GROUPS;
// The total number of elements.
static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW;
// Ctor.
template<typename Params>
inline __device__ Softmax_base(const Params &params, void *smem, int tidx)
: // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {
// Move to the 1st mask loaded by the thread+ tidx;
// packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);
// Extract the position in the warp.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// Decompose the warp index into M and N.
int warp_m = warp % Cta_tile::WARPS_M;
int warp_n = warp / Cta_tile::WARPS_M;
// Decompose the warp-n index into group/position-inside-the-group.
int warp_g = warp_n / ELEMENTS_PER_ROW;
int warp_i = warp_n % ELEMENTS_PER_ROW;
// The location written by the threads.
int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
int write_col = warp_i;
// Assemble the write pointer.
smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];
// Assemble the read pointer.
smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
}
template<bool zero=false, typename Mask>
inline __device__ void apply_mask(const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ii = 0; ii < 2; ++ii ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
#pragma unroll
for( int jj = 0; jj < 4; ++jj ) {
if( !mask.is_valid(mi, ni, ii, jj) ) {
elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY;
}
}
}
}
}
}
// Apply the exp to all the elements.
template <bool max_in_base2=false, bool elt_in_base2=false>
inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
constexpr float kLog2e = M_LOG2E;
const float max_base2 = max_in_base2 ? max[mi] : max[mi] * kLog2e;
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
// elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
elt_[mi][ni] = apply_exp2_(elt_in_base2 ? elt_[mi][ni] : elt_[mi][ni] * kLog2e,
max_base2);
}
}
}
// Apply the exp to all the elements.
template <bool scale_max=true>
inline __device__ void scale_apply_exp(const float (&max)[MMAS_M * 2], const float scale_) {
const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E;
const float scale = scale_ * M_LOG2E;
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
const float max_scaled = max[mi] * max_scale;
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled);
}
}
}
// Apply the exp to all the elements.
inline __device__ void apply_exp(const __half2 (&max)[MMAS_M]) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
constexpr float kLog2e = M_LOG2E;
const float2 max_f = __half22float2(max[mi]);
const float max0_log2e = max_f.x * kLog2e, max1_log2e = max_f.y * kLog2e;
#pragma unroll
for (int ni = 0; ni < MMAS_N * 4; ++ni) {
float2 elt = __half22float2(elt_half_[mi][ni]);
elt_[mi * 2 + 0][ni] = apply_exp2_(elt.x * kLog2e, max0_log2e);
elt_[mi * 2 + 1][ni] = apply_exp2_(elt.y * kLog2e, max1_log2e);
// __half2 out = apply_exp_(elt_half_[mi][ni], max[mi]);
// float2 outf = __half22float2(out);
// elt_[mi * 2 + 0][ni] = outf.x;
// elt_[mi * 2 + 1][ni] = outf.y;
}
}
}
// Apply the exp to all the elements.
template <bool max_in_base2=false>
inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
constexpr float kLog2e = M_LOG2E;
const float max_base2 = max_in_base2 ? max[ni] : max[ni] * kLog2e;
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
}
}
}
// inline __device__ void apply_exp_col(const float (&max)[MMAS_N]) {
// constexpr float kLog2e = M_LOG2E;
// #pragma unroll
// for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
// float max_base2 = max_in_base2 ? max[ni / 4] : max[ni / 4] * kLog2e;
// max_base2 = __shfl_sync(0xffffffff, max_base2, (ni % 4) * 8 + threadIdx.x % 8);
// #pragma unroll
// for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
// elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
// }
// }
// }
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout(Philox &ph, uint32_t p_dropout_in_uint) {
// We encode the dropout pattern in the sign bit of the non-negative
// softmax to distinguish from pre-existing zeros
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; mi++ ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni++ ) {
uint4 tmp = ph();
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
// }
elt_[mi][4 * ni + 0] =
encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]);
elt_[mi][4 * ni + 1] =
encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]);
elt_[mi][4 * ni + 2] =
encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]);
elt_[mi][4 * ni + 3] =
encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]);
}
}
}
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout(Philox &ph0, Philox &ph1, uint32_t p_dropout_in_uint) {
// We encode the dropout pattern in the sign bit of the non-negative
// softmax to distinguish from pre-existing zeros
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; mi++ ) {
static_assert(MMAS_N % 2 == 0);
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni += 2 ) {
uint4 tmp = ph0();
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph0, Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
// }
elt_[mi][4 * ni + 0] =
encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * ni + 0]);
elt_[mi][4 * ni + 1] =
encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * ni + 1]);
elt_[mi][4 * ni + 2] =
encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * ni + 2]);
elt_[mi][4 * ni + 3] =
encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * ni + 3]);
tmp = ph1();
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph1, Philox: %u, %u, %u, %u\n", ni + 1, tmp.x, tmp.y, tmp.z, tmp.w);
// }
elt_[mi][4 * (ni + 1) + 0] =
encode_dropout(tmp.x <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 0]);
elt_[mi][4 * (ni + 1) + 1] =
encode_dropout(tmp.y <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 1]);
elt_[mi][4 * (ni + 1) + 2] =
encode_dropout(tmp.z <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 2]);
elt_[mi][4 * (ni + 1) + 3] =
encode_dropout(tmp.w <= p_dropout_in_uint, elt_[mi][4 * (ni + 1) + 3]);
}
}
}
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
// We encode the dropout pattern in the sign bit of the non-negative
// softmax to distinguish from pre-existing zeros
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni++ ) {
uint16_t tmp[8];
fmha::uint4_to_ushort8(ph(), tmp);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
elt_[mi * 2 + ii][4 * ni + jj] =
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
}
}
}
}
}
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout_16bits(Philox &ph0, Philox &ph1, uint16_t p_dropout_in_uint16_t) {
// We encode the dropout pattern in the sign bit of the non-negative
// softmax to distinguish from pre-existing zeros
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
static_assert(MMAS_N % 2 == 0);
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni += 2 ) {
uint16_t tmp[8];
fmha::uint4_to_ushort8(ph0(), tmp);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
elt_[mi * 2 + ii][4 * ni + jj] =
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
}
}
fmha::uint4_to_ushort8(ph1(), tmp);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
elt_[mi * 2 + ii][4 * (ni + 1) + jj] =
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * (ni + 1) + jj]);
}
}
}
}
}
// Scale all the elements.
inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
float inv_sum[MMAS_M * 2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
}
// Update the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] *= inv_sum[mi];
}
}
}
// Subtract all elements by dp_sum
inline __device__ void subtract_dp_sum(const float (&dp_sum)[MMAS_M * 2]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] -= dp_sum[mi];
}
}
}
// The pointer to the mask.
const char *packed_mask_ptr_;
// Shared memory for the CTA-wide reduction.
float *smem_, *smem_write_, *smem_read_;
// The current thread index.
int tidx_;
// The elements.
float elt_[MMAS_M * 2][MMAS_N * 4];
__half2 elt_half_[MMAS_M][MMAS_N * 4];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile, typename Kernel_traits>
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
// The base class.
using Base = Softmax_base<Cta_tile, Kernel_traits>;
// The fragment.
using Fragment_a = fmha::Fragment_a<fmha::Row>;
static_assert(Fragment_a::NUM_REGS == 4);
static constexpr int WARPS_M = Cta_tile::WARPS_M;
static constexpr int WARPS_N = Cta_tile::WARPS_N;
// The MMAs.
static constexpr int MMAS_M = Base::MMAS_M;
static constexpr int MMAS_N = Base::MMAS_N;
// The accumulators.
using Accumulator = fmha::Fragment_accumulator;
using Accumulator_out = Fragment<uint16_t, 8>;
static_assert(Accumulator_out::NUM_REGS == 4);
static_assert(std::is_same<Accumulator::Data_type, float>::value);
using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
// Ctor.
template<typename Params>
inline __device__ Softmax(const Params &params, void *smem, int tidx)
: Base(params, smem, tidx)
, params_scale_bmm1_(params.scale_bmm1)
, smem_sum_(static_cast<float*>(smem), tidx)
, smem_max_(static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
}
// Pack the data to a fragment for the next GEMM.
template<int K, int M>
inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ki = 0; ki < K; ++ki ) {
// 1st row - 4 elements per row.
float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];
float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];
// 2nd row - 4 elements per row.
float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];
float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];
float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];
float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];
// Pack to 4 registers.
dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
}
}
}
// Scale FP32 fragments
inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
const float scalef = reinterpret_cast<const float &>(this->params_scale_bmm1_);
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;
// 2nd row - 4 elements per row.
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;
}
}
}
// Scale FP32 fragments
inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
}
}
}
// Scale FP32 fragments
template <typename Mask>
inline __device__ void unpack_noscale_half_and_apply_mask(const Accumulator (&acc)[MMAS_M][MMAS_N],
const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
float tmp[2][4];
// 1st row - 4 elements per row.
tmp[0][0] = mask.is_valid(mi, ni, 0, 0) ? acc[mi][ni].elt(0) : -INFINITY;
tmp[0][1] = mask.is_valid(mi, ni, 0, 1) ? acc[mi][ni].elt(1) : -INFINITY;
tmp[0][2] = mask.is_valid(mi, ni, 0, 2) ? acc[mi][ni].elt(4) : -INFINITY;
tmp[0][3] = mask.is_valid(mi, ni, 0, 3) ? acc[mi][ni].elt(5) : -INFINITY;
// 2nd row - 4 elements per row.
tmp[1][0] = mask.is_valid(mi, ni, 1, 0) ? acc[mi][ni].elt(2) : -INFINITY;
tmp[1][1] = mask.is_valid(mi, ni, 1, 1) ? acc[mi][ni].elt(3) : -INFINITY;
tmp[1][2] = mask.is_valid(mi, ni, 1, 2) ? acc[mi][ni].elt(6) : -INFINITY;
tmp[1][3] = mask.is_valid(mi, ni, 1, 3) ? acc[mi][ni].elt(7) : -INFINITY;
this->elt_half_[mi][4 * ni + 0] = __floats2half2_rn(tmp[0][0], tmp[1][0]);
this->elt_half_[mi][4 * ni + 1] = __floats2half2_rn(tmp[0][1], tmp[1][1]);
this->elt_half_[mi][4 * ni + 2] = __floats2half2_rn(tmp[0][2], tmp[1][2]);
this->elt_half_[mi][4 * ni + 3] = __floats2half2_rn(tmp[0][3], tmp[1][3]);
}
}
}
template<bool zero_init=true, typename Operator>
__device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
#pragma unroll
for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
frag[mi] = zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]);
#pragma unroll
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
frag[mi] = op(frag[mi], this->elt_[mi][ni]);
}
}
}
template<typename Operator>
__device__ inline void thread_reduce_(__half2 (&frag)[MMAS_M], Operator &op) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
frag[mi] = this->elt_half_[mi][0];
#pragma unroll
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
frag[mi] = op(frag[mi], this->elt_half_[mi][ni]);
}
}
}
template<bool zero_init=true, typename Operator>
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
thread_reduce_<zero_init>(frag, op);
quad_reduce(frag, frag, op);
smem_red.store(frag);
__syncthreads();
typename Smem_tile_red::read_t tmp[2 * MMAS_M];
smem_red.load(tmp);
quad_allreduce(frag, tmp, op);
}
template<typename Operator>
__device__ inline void reduce_(__half2 (&frag)[MMAS_M], Operator &op, Smem_tile_red & smem_red) {
thread_reduce_(frag, op);
quad_reduce(frag, frag, op);
smem_red.store(frag);
__syncthreads();
typename Smem_tile_red::read_half_t tmp[MMAS_M];
smem_red.load(tmp);
quad_allreduce(frag, tmp, op);
}
template<bool zero_init=true>
__device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
MaxOp<float> max;
reduce_<zero_init>(frag, max, smem_max_);
}
__device__ inline void reduce_max(__half2 (&frag)[MMAS_M]){
MaxOp<__half2> max;
reduce_(frag, max, smem_max_);
}
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
SumOp<float> sum;
reduce_(frag, sum, smem_sum_);
}
template<bool zero_init=true>
__device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]){
SumOp<float> sum;
thread_reduce_<zero_init>(frag, sum);
quad_reduce(frag, frag, sum);
smem_sum_.store(frag);
}
template<int NROWS, typename Operator>
__device__ inline void reduce_after_sync_(float (&frag)[NROWS][MMAS_M],
const int (&rows)[NROWS],
Operator &op, Smem_tile_red & smem_red) {
#pragma unroll
for (int ii = 0; ii < NROWS; ii++) {
typename Smem_tile_red::read_t tmp[MMAS_M];
smem_red.load_row(tmp, rows[ii]);
quad_allreduce(frag[ii], tmp, op);
}
}
template<int NROWS>
__device__ inline void reduce_sum_after_sync_(float (&frag)[NROWS][MMAS_M],
const int (&rows)[NROWS]){
SumOp<float> sum;
reduce_after_sync_(frag, rows, sum, smem_sum_);
}
template<int NROWS>
__device__ inline void reduce_max_after_sync_(float (&frag)[NROWS][MMAS_M],
const int (&rows)[NROWS]){
MaxOp<float> max;
reduce_after_sync_(frag, rows, max, smem_max_);
}
const uint32_t params_scale_bmm1_;
Smem_tile_red smem_max_;
Smem_tile_red smem_sum_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Row {};
struct Col {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, bool = (M & (M-1)) == 0 >
struct Next_power_of_two {
};
template< int M >
struct Next_power_of_two< M, true > { enum { VALUE = M }; };
template<>
struct Next_power_of_two< 3, false> { enum { VALUE = 4 }; };
template<>
struct Next_power_of_two< 5, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 6, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 7, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 9, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 10, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 11, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 12, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 13, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 14, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 15, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 24, false> { enum { VALUE = 32 }; };
template<>
struct Next_power_of_two< 48, false> { enum { VALUE = 64 }; };
template<>
struct Next_power_of_two< 80, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two< 96, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two<112, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two<144, false> { enum { VALUE = 256 }; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, bool = (N & (N-1)) == 0 >
struct Prev_power_of_two {
};
template< int N >
struct Prev_power_of_two< N, true > { enum { VALUE = N }; };
template<>
struct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; };
template<>
struct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; };
template<>
struct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; };
template<>
struct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, int N >
struct Div_up {
enum { VALUE = (M + N-1) / N };
};
constexpr int DivUpConstexpr(int M, int N) { return (M + N - 1) / N; }
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B >
struct Max {
enum { VALUE = A >= B ? A : B };
};
constexpr int MaxConstexpr(int A, int B) { return A >= B ? A : B; }
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B, int C >
struct Max_3 {
enum { VALUE = Max<Max<A, B>::VALUE, C>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B >
struct Min {
enum { VALUE = A <= B ? A : B };
};
constexpr int MinConstexpr(int A, int B) { return A <= B ? A : B; }
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int SIZE_IN_BYTES >
struct Uint_from_size_in_bytes {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<1> {
using Type = uint8_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<2> {
using Type = uint16_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<4> {
using Type = uint32_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<8> {
using Type = uint2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<16> {
using Type = uint4;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int WARPS_M, int WARPS_N, int WARPS_K >
struct Warp_masks {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; };
template<>
struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; };
template<>
struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; };
template<>
struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; };
template<>
struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; };
template<>
struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; };
template<>
struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; };
template<>
struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; };
template<>
struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; };
template<>
struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; };
template<>
struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; };
template<>
struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; };
template<>
struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; };
template<>
struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename T >
inline __device__ __host__ T div_up(T m, T n) {
return (m + n-1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int clz(int x) {
for( int i = 31; i >= 0; --i ) {
if( (1 << i) & x ) {
return 31 - i;
}
}
return 32;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int find_log_2(int x, bool round_up = false) {
int a = 31 - clz(x);
if( round_up ) {
a += (x & (x-1)) ? 1 : 0;
}
return a;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) {
// uint32_t c;
// asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
// return c;
__half2 result = __hmul2(reinterpret_cast<const __half2 (&)>(a),
reinterpret_cast<const __half2 (&)>(b));
return reinterpret_cast<uint32_t(&)>(result);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
uint2 c;
c.x = hmul2(a.x, b.x);
c.y = hmul2(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
uint4 c;
c.x = hmul2(a.x, b.x);
c.y = hmul2(a.y, b.y);
c.z = hmul2(a.z, b.z);
c.w = hmul2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
uint4 c;
c.x = hmul2(a, b.x);
c.y = hmul2(a, b.y);
c.z = hmul2(a, b.z);
c.w = hmul2(a, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
uint32_t res;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb));
#else
const uint32_t zero = 0u;
asm volatile( \
"{\n" \
"\t .reg .f16x2 sela;\n" \
"\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
"\t and.b32 %0, sela, %1;\n"
"}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
return res;
}
static inline __device__ uint32_t habs2(uint32_t x) {
uint32_t res;
asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
template< typename T >
static inline __device__ T clamp(T x, T lb, T ub) {
return x < lb ? lb : (x > ub ? ub : x);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t clamp_to_zero(uint16_t x) {
uint16_t mask;
asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x));
return mask & x;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t float_to_half(float f) {
uint16_t h;
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
return h;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float2_to_half2(float a, float b) {
uint32_t c;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
#else
uint16_t lo = float_to_half(a);
uint16_t hi = float_to_half(b);
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi));
#endif
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float_to_half2(float a) {
return float2_to_half2(a,a);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float2_to_half2(const float2 &f) {
return float2_to_half2(f.x, f.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) {
uint2 d;
d.x = float2_to_half2(x, y);
d.y = float2_to_half2(z, w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
d = hrelu2(hfma2(a, b, c));
#endif
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t h0_h0(uint32_t x) {
uint32_t y;
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n"
: "=r"(y) : "r"(x));
return y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float h0_to_float(uint32_t h2) {
float f;
asm volatile("{\n" \
".reg .f16 lo, hi;\n" \
"mov.b32 {lo, hi}, %1;\n" \
"cvt.f32.f16 %0, lo;\n" \
"}\n" : "=f"(f) : "r"(h2));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t h1_h1(uint32_t x) {
uint32_t y;
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n"
: "=r"(y) : "r"(x));
return y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {
uint16_t d;
asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) {
return hadd2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hadd4(uint2 a, uint2 b) {
uint2 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hadd(uint2 a, uint2 b) {
return hadd4(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hadd8(uint4 a, uint4 b) {
uint4 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
c.z = hadd2(a.z, b.z);
c.w = hadd2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Converted two half2's into float, then take their dot product.
// inline __device__ void hfma2_to_float(float &sum, const __half2 a, const __half2 b) {
static inline __device__ float hfma2_to_float(const __half2 a, const __half2 b) {
float2 af = __half22float2(a);
float2 bf = __half22float2(b);
return af.x * bf.x + af.y * bf.y;
// sum += af.x * bf.x + af.y * bf.y;
// sum = __fmaf_rn(sum, af.x, bf.x);
// sum = __fmaf_rn(sum, af.y, bf.y);
// float2 prod = __half22float2(__hmul2(a, b));
// sum += prod.x + prod.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Converted two vectors of 8 half's into float, then take their dot product.
static inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
float sum;
sum = fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.x),
reinterpret_cast<const __half2&>(b.x));
sum += fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.y),
reinterpret_cast<const __half2&>(b.y));
sum += fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.z),
reinterpret_cast<const __half2&>(b.z));
sum += fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.w),
reinterpret_cast<const __half2&>(b.w));
return sum;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 fadd4(uint4 a, uint4 b) {
float4 c;
c.x = reinterpret_cast<const float&>(a.x) + reinterpret_cast<const float&>(b.x);
c.y = reinterpret_cast<const float&>(a.y) + reinterpret_cast<const float&>(b.y);
c.z = reinterpret_cast<const float&>(a.z) + reinterpret_cast<const float&>(b.z);
c.w = reinterpret_cast<const float&>(a.w) + reinterpret_cast<const float&>(b.w);
return reinterpret_cast<const uint4&>(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 fmul4(uint4 a, float b) {
float4 c;
c.x = reinterpret_cast<const float &>(a.x) * b;
c.y = reinterpret_cast<const float &>(a.y) * b;
c.z = reinterpret_cast<const float &>(a.z) * b;
c.w = reinterpret_cast<const float &>(a.w) * b;
return reinterpret_cast<const uint4 &>(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hadd(uint4 a, uint4 b) {
return hadd8(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float half_to_float(uint16_t h) {
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float2 half2_to_float2(uint32_t x) {
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x));
return make_float2(half_to_float(lo), half_to_float(hi));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) {
float2 tmp = half2_to_float2(h);
x = tmp.x;
y = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {
uint16_t d;
asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {
uint16_t d;
asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void uint4_to_ushort8(const uint4 a, uint16_t (&b)[8]) {
uint32_t *b_tmp = reinterpret_cast<uint32_t *>(&b[0]);
b_tmp[0] = a.x;
b_tmp[1] = a.y;
b_tmp[2] = a.z;
b_tmp[3] = a.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float sigmoid(float x) {
return 1.f / (1.f + expf(-x));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint16_t &dst) {
dst = uint16_t(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint32_t &dst) {
dst = 0u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint2 &dst) {
dst = make_uint2(0u, 0u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint4 &dst) {
dst = make_uint4(0u, 0u, 0u, 0u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// P R E D I C A T E P A C K I N G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE };
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// G E N E R I C P R E D I C A T E D L D G S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M, typename Functor >
inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) {
// The number of complete bytes (where we use all the predicates in a byte).
enum { COMPLETE = N / PREDS_PER_BYTE };
// Make sure we did allocate enough predicates.
static_assert(Div_up<COMPLETE, BYTES_PER_REG>::VALUE <= M, "");
// The remainder.
enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE };
// Make sure we got the math right and the remainder is between 0 and 3.
static_assert(REMAINDER >= 0 && REMAINDER <= 3, "");
// The mask to extract the predicates.
enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 };
// Clear the fetch registers.
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
fct.clear(ii);
}
// Run complete steps.
bool p[PREDS_PER_BYTE];
#pragma unroll
for( int ii = 0; ii < COMPLETE; ++ii ) {
// The predicate.
uint32_t reg = preds[ii / BYTES_PER_REG];
// Extract the predicates.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);
p[jj] = (reg & mask) != 0u;
}
// Issue the loads.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
fct.load(ii * PREDS_PER_BYTE + jj, p[jj]);
}
}
// Skip the rest of the code if we do not have a remainder.
if( REMAINDER > 0 ) {
// The mask to extract the predicates.
enum { REMAINDER_MASK = (1 << REMAINDER) - 1 };
// The predicate register.
uint32_t reg = preds[COMPLETE / BYTES_PER_REG];
// Extract the predicates.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);
p[jj] = (reg & mask) != 0u;
}
// Issue the loads.
#pragma unroll
for( int ii = 0; ii < REMAINDER; ++ii ) {
fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, typename Functor >
inline __device__ void load_(Functor &fct, uint32_t preds) {
uint32_t tmp[1] = { preds };
load_<M>(fct, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint8_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint8_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint16_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint16_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint32_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint32_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint2 &dst, const void *ptr) {
dst = *reinterpret_cast<const uint2*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint4 &dst, const void *ptr) {
dst = *reinterpret_cast<const uint4*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N >
struct Ldg_functor {
// Ctor.
inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N])
: fetch_(fetch), ptrs_(ptrs) {
}
// Clear the element.
inline __device__ void clear(int ii) {
fmha::clear(fetch_[ii]);
}
// Trigger the loads.
inline __device__ void load(int ii, bool p) {
if( p ) {
ldg(fetch_[ii], ptrs_[ii]);
}
}
// The fetch registers.
Data_type (&fetch_)[N];
// The pointers.
const void* (&ptrs_)[N];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N, int M >
inline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
Ldg_functor<Data_type, N> fct(fetch, ptrs);
load_<N>(fct, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint8_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint16_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint32_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint2, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint4, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint16_t &dst, uint32_t ptr) {
asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint32_t &dst, uint32_t ptr) {
asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint2 &dst, uint32_t ptr) {
asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint4 &dst, uint32_t ptr) {
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x)
, "=r"(dst.y)
, "=r"(dst.z)
, "=r"(dst.w)
: "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S M
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
: "=r"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n"
: "=r"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint8_t val) {
*reinterpret_cast<uint8_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint16_t val) {
*reinterpret_cast<uint16_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint32_t val) {
*reinterpret_cast<uint32_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint2 val) {
*reinterpret_cast<uint2*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint4 val) {
*reinterpret_cast<uint4*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint16_t val) {
asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint32_t val) {
asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint2 val) {
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n"
:
: "r"(ptr)
, "r"(val.x)
, "r"(val.y));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint4 val) {
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
:
: "r"(ptr)
, "r"(val.x)
, "r"(val.y)
, "r"(val.z)
, "r"(val.w));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N >
inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
sts(ptrs[ii], data[ii]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {
sts_<uint16_t, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {
sts_<uint32_t, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {
sts_<uint2, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
sts_<uint4, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};
template <>
struct MaxOp<__half2> {
__device__ inline __half2 operator()(__half2 const &x, __half2 const &y) { return __hmax2(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(__half2 (&dst)[M], __half2 (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_reduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(__half2 (&dst)[M], float2 (&src)[M], Operator &op) {
__half2 tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(reinterpret_cast<const __half2 &>(src[mi].x),
reinterpret_cast<const __half2 &>(src[mi].y));
}
quad_reduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = Allreduce<4>::run(dst[mi], op);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(__half2 (&dst)[M], __half2 (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = Allreduce<4>::run(dst[mi], op);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_allreduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(__half2 (&dst)[M], float2 (&src)[M], Operator &op) {
__half2 tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(reinterpret_cast<const __half2 &>(src[mi].x),
reinterpret_cast<const __half2 &>(src[mi].y));
}
quad_allreduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/* Copyright (c) 2022, Tri Dao.
*/
#include "fmha.h"
#include "fmha_block_dgrad_kernel_1xN_loop.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_block_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}
template<typename Kernel_traits>
void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
static_assert(smem_size_dp_sum == 16 * 4 * 2);
constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;
bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping"
bool is_causal = params.is_causal;
auto kernel = is_dropout
? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
: (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
constexpr int N = Kernel_traits::Cta_tile_p::N;
if (params.s == N) {
kernel = is_dropout
? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
: (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
} else if (params.s == N * 2) {
kernel = is_dropout
? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
: (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
}
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
void run_fmha_block_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
if (params.d == 16) {
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if (params.d == 32) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if (params.d == 64) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
}
\ No newline at end of file
/* Copyright (c) 2022, Tri Dao.
*/
#pragma once
#include "fmha_fprop_kernel_1xN.h"
#include "fmha_kernel.h"
#include "fmha_blockmask.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Smem_dp_sum, int M>
inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M],
Smem_dp_sum smem, const int buffer_idx) {
#pragma unroll
for (int mi = 0; mi < M; ++mi) {
sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi]));
}
static_assert(M == 1);
smem.store(sum[0], buffer_idx);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, Prng &ph,
const int loop_step_idx) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dq = typename Kernel_traits::Cta_tile_o;
// The description of the CTA tile for the 3rd batched GEMM.
using Cta_tile_dkv =
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128);
static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64);
static_assert(Cta_tile_dkv::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dq = fmha::Hmma_tile<Cta_tile_dq>;
// The MMA tile for the 3rd GEMM.
using Mma_tile_dkv = fmha::Hmma_tile<Cta_tile_dkv>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to reload Q transposed.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K^T. Treat K^T as V
using Smem_tile_kt = typename Kernel_traits::Smem_tile_v;
// Treating V as K. We need to use Kernel_traits::Smem_tile_k otherwise loading will be wrong
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load dO.
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
// The shared memory tile to load dO.
// Treating dO as Q.
using Smem_tile_do = typename Kernel_traits::Smem_tile_q;
// The shared memory tile to reload dO transposed.
using Smem_tile_dot = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The global memory tile to load O.Loading O here is similar to loading dO.
using Gmem_tile_o = Gmem_tile_do;
// The global memory tile to store dQ.
// using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
using Gmem_tile_dq = fmha::Gmem_tile_dq<Cta_tile_dq>;
using Gmem_tile_dq_tmp = fmha::Gmem_tile_o<Cta_tile_dq, 4>;
// The shared memory tile to swizzle dQ.
using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store dV.
using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle dV.
using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
// The global memory tile to store dK.
using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle dK.
using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum;
// using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
// Shared memory.
extern __shared__ char smem_[];
// Shared memory layout if we keep V in registers:
// dO | Q | K / V | dQ | S | dP | dP_sum
// dV | dK
// Shared memory layout if we keep V shared memory:
// dO | Q | K | V | dQ | S | dP | dP_sum
// dV | dK
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
// if( binfo.stop_early() ) return;
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
Blockmask blockmask(params, loop_step_idx);
int block_row_idx = 0;
int mask_val = blockmask.mask_val(0);
if (mask_val == -1) return;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("mask_val = %d.\n", mask_val);
// }
Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the global memory tile loader for dQ.
Gmem_tile_dq gmem_dq(params, 0, binfo, tidx);
Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!!
Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for dO.
Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
// Allocate the shared memory tile loader for dO.
Smem_tile_do smem_do(&smem_[0], tidx);
Smem_tile_dot smem_dot(&smem_[0], tidx);
// Allocate the shared memory tile loader for Q^T.
// TODO: assert that this points to the same memory as gemm_q_k.smem_q
Smem_tile_qt smem_qt(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
Smem_tile_st smem_s(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], tidx);
Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
const int steps = params.s / Cta_tile_p::M;
// Wind gmem tiles to the correct position.
int block_row_idx_next = mask_val / 4;
int block_row_idx_to_move = block_row_idx_next - block_row_idx;
block_row_idx = block_row_idx_next;
gmem_q.move(block_row_idx_to_move);
gmem_do.move(block_row_idx_to_move);
gmem_o.move(block_row_idx_to_move);
gmem_dq.move(block_row_idx_to_move);
gmem_dq_tmp.move(block_row_idx_to_move);
// TODO: need to move gmem_s if we want the intermediate result for debugging
gmem_softmax_lse.move(block_row_idx_to_move);
gmem_softmax_d.move(block_row_idx_to_move);
block_row_idx = block_row_idx_next;
if (!Is_first) {
gmem_k.move(loop_step_idx);
gmem_v.move(loop_step_idx);
}
// Trigger the loads for K.
gmem_k.load();
// Trigger the loads for Q.
gmem_q.load();
// Trigger the loads for V.
gmem_v.load();
// Trigger the loads for dO.
gmem_do.load();
// Trigger the loads for O.
// if (Is_first) { gmem_o.load(); }
// if (true) { gmem_o.load(); }
if (Is_first || mask_val % 2 == 1) { gmem_o.load(); }
float p_lse[Mma_tile_p::MMAS_M * 2];
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
float dp_sum[Mma_tile_p::MMAS_M * 2];
// if (!Is_first) {
// if (false) {
if (!(Is_first || mask_val % 2 == 1)) {
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
}
float dp_sum_regs[Gmem_tile_do::LDGS];
Smem_dp_sum smem_dp_sum(reinterpret_cast<float *>(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE * 2]), tidx);
if (!Is_first) { __syncthreads(); }
// Commit the data for Q, dO, and V to shared memory.
gmem_q.commit(gemm_q_k.smem_q);
gmem_do.commit(smem_do);
// if (Is_first) {
// if (true) {
if (Is_first || mask_val % 2 == 1) {
dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0);
const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW;
if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row);
}
}
// Instead of scaling dP by rp_dropout, we scale V instead
if (Is_dropout) {
const uint32_t scale_dropout = params.scale_dropout;
#pragma unroll
for(int it=0; it < Gmem_tile_v::LDGS; it++){
gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
}
}
gmem_v.commit(smem_v);
// const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
// #pragma unroll
// for(int it=0; it < Gmem_tile_k::LDGS; it++){
// gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
// }
// Commit the data for K to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_k.commit(gemm_q_k.smem_k);
}
__syncthreads();
// Load the fragments for Q.
gemm_q_k.load_q();
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Kernel_traits::V_IN_REGS ? Mma_tile_p::MMAS_K : 2][Mma_tile_p::MMAS_N];
if (Kernel_traits::V_IN_REGS) {
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_k.commit(gemm_q_k.smem_k);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for K.
gemm_q_k.load_k();
// Load the fragments for K^T.
// typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
// smem_kt.load(frag_kt[0], 0);
// typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N];
// #pragma unroll
// for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) {
// smem_kt.load(frag_kt[ki], ki);
// }
// Create the object to do the softmax.
// We won't be using the shared memory for this softmax at all
Softmax softmax(params, smem_, tidx);
// Declare the accumulators for the 3rd gemm.
fmha::Fragment_accumulator acc_dv[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dv);
fmha::Fragment_accumulator acc_dk[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);
// Load over the entire sequence length.
for( int l = 0; l < steps; l++ ) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("block_row_idx = %d\n", block_row_idx);
// }
if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen) break;
int mask_val_next = l < steps - 1 ? blockmask.mask_val(l + 1) : -1;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("mask_val = %d, mask_val_next = %d\n", mask_val, mask_val_next);
// }
// Load the fragments for V.
// typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_N];
if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[0], 0); }
// Load the fragments for dO.
typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M];
smem_do.load(frag_do[0], 0);
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
gemm_q_k(acc_p);
// Load the mask for that iteration.
mask.load(block_row_idx);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack_noscale(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
// Scale by log-sum-exp of the softmax
// softmax.apply_exp(p_lse);
softmax.template scale_apply_exp</*scale_max=*/false>(p_lse, params.scale_bmm1f);
if (Is_dropout) {
// softmax.apply_dropout(ph, params.p_dropout_in_uint);
// softmax.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint);
softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t);
}
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
softmax.pack(frag_p);
// Store s * dmask to smem for transpose
smem_s.store(frag_p);
// Trigger the load for the next Q values.
bool not_last_iter = (l < steps - 1) && (mask_val_next != -1);
block_row_idx_next = mask_val_next / 4;
int block_row_idx_to_move = block_row_idx_next - block_row_idx;
if (not_last_iter) {
gemm_q_k.smem_q.move_to_next_write_buffer();
gmem_q.move(block_row_idx_to_move);
gmem_q.load();
}
// if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
// __syncthreads();
// }
bool is_first_read = Is_first || mask_val % 2 == 1;
// TD [2022-04-24]: if Is_first, then it's faster to set acc_dp to zero then subtract by
// dp_sum later. If !Is_first, then it's faster to set acc_dp to -dp_sum and don't subtract
// later. This is because loading dp_sum earlier uses more registers.
fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
// if (Is_first) {
// if (true) {
if (is_first_read) {
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dp);
} else {
#pragma unroll
for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
#pragma unroll
for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
#pragma unroll
for (int ii = 0; ii < 8; ++ii) {
acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
}
}
}
}
// Do this part of dP^T = (dO * V^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of dO values.
smem_do.load(frag_do[ki & 1], ki);
if (!Kernel_traits::V_IN_REGS) {
smem_v.load(frag_v[ki & 1], ki);
fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
} else {
fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
// printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y);
// tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[(ki - 1) & 1]));
// printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y);
// }
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
if (!Kernel_traits::V_IN_REGS) {
fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
} else {
fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
}
}
// Load the fragments for K^T.
typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
smem_kt.load(frag_kt[0], 0);
// if (Is_first) {
// if (true) {
if (is_first_read) {
const int quad = (tidx % Cta_tile_p::THREADS_PER_WARP) / 4;
const int row[2] = {quad, quad + 8};
smem_dp_sum.load(dp_sum, row, l % 2);
}
// Trigger the load for the next dO values.
if (not_last_iter) {
smem_do.move_to_next_write_buffer();
gmem_do.move(block_row_idx_to_move);
gmem_do.load();
gmem_o.move(block_row_idx_to_move);
// if (Is_first) {
// if (true) {
if (Is_first || mask_val_next % 2 == 1) {
gmem_o.load();
}
}
softmax.unpack_noscale(acc_dp);
// // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
// // will be zero.
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
// if (Is_first) { softmax.subtract_dp_sum(dp_sum); }
// if (true) { softmax.subtract_dp_sum(dp_sum); }
if (is_first_read) { softmax.subtract_dp_sum(dp_sum); }
Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
softmax.pack(frag_dp);
if (!Is_dropout) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
frag_p[mi][ni].hmul(frag_dp[mi][ni]);
}
}
} else {
__half2 dp_sum_half[Mma_tile_p::MMAS_M * 2];
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]);
}
const __half zero_h = __half(0.f);
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
#pragma unroll
for (int ii = 0; ii < 4; ++ii) {
const __half2 p = frag_p[mi][ni].template elt_as<__half2>(ii);
const __half2 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__half2>(ii));
// If this element is dropped, then frag_p stores -p instead of p.
// So pd holds -p * dp_sum in that case.
const __half2 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]);
const __half low = __low2half(p) >= zero_h ? __low2half(pdp) : __low2half(pd);
const __half high = __high2half(p) >= zero_h ? __high2half(pdp) : __high2half(pd);
frag_p[mi][ni].template elt_as<__half2>(ii) = __halves2half2(low, high);
}
}
}
}
// Store dp to smem for transpose
smem_dp.store(frag_p);
// gmem_s.store(frag_p, mask);
// gmem_s.move();
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dq::WARPS_K>::apply(acc_dq);
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_kt.load(frag_kt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dq::MMAS_K;
fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
static_assert(Gmem_tile_dq::LOOPS == 1);
// Swizzle the elements and do the final reduction.
smem_dq.store(acc_dq, 0);
typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N];
static_assert(Smem_tile_dot::Fragment::NUM_REGS == 4);
static_assert(Mma_tile_dkv::MMAS_K == 1);
smem_dot.load(frag_dot[0], 0);
// Threads in a warp is communicating via shared memory (smem_s and smem_dp)
__syncwarp();
typename Smem_tile_st::Fragment frag_s[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
smem_s.load(frag_s);
if (Is_dropout) {
#pragma unroll
for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
frag_s[ki][mi].hrelu_();
}
}
}
#pragma unroll
for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_dot.load(frag_dot[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dkv::MMAS_K;
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
}
// __syncthreads();
// Commit the values for Q and dO into shared memory.
if (not_last_iter) {
gmem_q.commit(gemm_q_k.smem_q);
}
uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP];
// if (!Is_first) { gmem_dq_tmp.load(dq_out, 0); }
if (!is_first_read) { gmem_dq_tmp.load(dq_out, 0); }
// __syncthreads();
// Commit the values for Q and dO into shared memory.
if (not_last_iter) {
gmem_do.commit(smem_do);
// if (Is_first) {
// if (true) {
gmem_softmax_d.move(block_row_idx_to_move);
if (Is_first || mask_val_next % 2 == 1) {
// dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum);
// smem_dp_sum.move_to_next_write_buffer();
dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2);
const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW;
if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row_1);
}
}
gmem_softmax_lse.move(block_row_idx_to_move);
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
// if (!Is_first) {
if (!(Is_first || mask_val_next % 2 == 1)) {
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
}
}
typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
smem_dp.load(frag_dpt);
gemm_q_k.reload_k();
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dkv::MMAS_N];
static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
static_assert(Mma_tile_dkv::MMAS_K == 1);
smem_qt.load(frag_qt[0], 0);
#pragma unroll
for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dkv::MMAS_K;
fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Make sure dQ is in shared memory.
__syncthreads();
// Load from shared memory.
is_first_read ? smem_dq.template load</*zero_init=*/true>(dq_out) : smem_dq.template load</*zero_init=*/false>(dq_out);
const bool is_final_write =
Is_last
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
|| ((mask_val & 0x2) != 0)
|| ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
if (is_final_write) {
// if (Is_dropout) {
// dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
// }
dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f);
// Output the values.
gmem_dq.store(dq_out, 0);
} else {
// Output the values.
gmem_dq_tmp.store(dq_out, 0);
}
// Move to the next part of the output.
gmem_dq.move(block_row_idx_to_move);
if (!(Is_first && Is_last)) { gmem_dq_tmp.move(block_row_idx_to_move); }
// // Make sure the data is in shared memory.
// __syncthreads();
// Commit the values for Q and dO into shared memory.
if (not_last_iter) {
gemm_q_k.smem_q.move_to_next_read_buffer();
gemm_q_k.reload_q();
smem_qt.move_to_next_read_buffer();
// smem_qt.load(frag_qt[0], 0);
smem_do.move_to_next_read_buffer();
smem_dot.move_to_next_read_buffer();
// smem_dot.load(frag_dot[0], 0);
}
if (mask_val_next == -1) break;
mask_val = mask_val_next;
block_row_idx += block_row_idx_to_move;
} // Outer loop over the sequence length.
if (Is_dropout) {
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
acc_dv[mi][ni].mul_(params.rp_dropout);
}
}
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
// }
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
// acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f);
acc_dk[mi][ni].mul_(params.scale_bmm1f);
}
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
// }
__syncthreads();
// TODO [TD - 2022-05-04]: Are there cases where the shared mem for dV and dK are larger than
// the total amount of shared mem?
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[0], tidx);
smem_dv.store(acc_dv);
// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
smem_dk.store(acc_dk);
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out);
Qkv_params dv_params;
dv_params.qkv_ptr = params.dqkv_ptr;
dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dv_params.h = params.h;
Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
if (!Is_first) {
gmem_dv.move(loop_step_idx);
}
gmem_dv.store(dv_out);
uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
// for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
// dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
// }
Qkv_params dk_params;
dk_params.qkv_ptr = params.dqkv_ptr;
dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dk_params.h = params.h;
Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
if (!Is_first) {
gmem_dk.move(loop_step_idx);
}
gmem_dk.store(dk_out);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// loop_steps = -1 means the number of steps will be params.s / Kernel_traits::Cta_tile_p::N.
// This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params>
inline __device__ void compute_block_dq_dk_dv_1xN(const Params &params) {
constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const int tidx_global = (bidb * params.h + bidh) * blockDim.x + tidx;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
if (loop_steps == 1) {
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
} else if (loop_steps == 2) {
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1);
} else {
if (params.s == N_per_loop) {
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
} else {
const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx);
}
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, max_loop_steps - 1);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_block_fprop_kernel_1xN.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_block_fprop_fp16_sm80_loop_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_block_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
}
template<typename Kernel_traits>
void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
const bool configure) {
bool is_causal = launch_params.params.is_causal;
// TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
auto kernel = launch_params.is_dropout
? (is_causal
? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, false>)
: (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, false>))
: (is_causal
? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
: (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));
constexpr int N = Kernel_traits::Cta_tile_p::N;
const int loop_steps = (launch_params.params.s + N - 1) / N;
constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
// Don't need smem_size_softmax_lse if we're not looping
const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
+ (loop_steps > 1 ? smem_size_softmax_lse : 0);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
if (configure) {
using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
constexpr int M = Kernel_traits::Cta_tile_p::M;
size_t STEPS = (launch_params.params.s + M - 1) / M;
constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
launch_params.elts_per_thread = elts_per_head;
return;
}
dim3 grid(launch_params.params.h, launch_params.params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
void run_fmha_block_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
const bool configure) {
if (launch_params.params.d == 16) {
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
run_fmha_block_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if (launch_params.params.d == 32) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
run_fmha_block_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if (launch_params.params.d == 64) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
run_fmha_block_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
}
\ No newline at end of file
/***************************************************************************************************
* Copyright (c) 2022, Tri Dao.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_fprop_kernel_1xN.h"
#include "fmha_kernel.h"
#include "fmha_blockmask.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void device_block_1xN_(const Params &params, const int bidb, const int bidh, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
using Gmem_tile_o_tmp = fmha::Gmem_tile_o<Cta_tile_o, 4>;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum;
using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
// if( binfo.stop_early() ) return;
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
Blockmask blockmask(params, loop_step_idx);
int block_row_idx = 0;
int mask_val = blockmask.mask_val(0);
if (mask_val == -1) return;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("mask_val = %d.\n", mask_val);
// }
Gemm1 gemm_q_k(smem_, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
// Wind gmem tiles to the correct position.
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
int block_row_idx_next = mask_val / 4;
int block_row_idx_to_move = block_row_idx_next - block_row_idx;
gmem_q.move(block_row_idx_to_move);
gmem_o.move(block_row_idx_to_move);
gmem_o_tmp.move(block_row_idx_to_move);
if (Return_softmax) { gmem_s.move(block_row_idx_to_move); }
gmem_softmax_lse.move(block_row_idx_to_move);
block_row_idx = block_row_idx_next;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("begin = %d, steps = %d\n", begin, steps);
// }
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);
if (!Is_first) {
gmem_k.move(loop_step_idx);
gmem_v.move(loop_step_idx);
if (Return_softmax) { gmem_s.move(loop_step_idx * steps); }
}
// Trigger the loads for K.
gmem_k.load();
// Trigger the loads for Q.
gmem_q.load();
// Trigger the loads for V.
gmem_v.load();
if (!Is_first) { __syncthreads(); }
float p_prev_lse[Mma_tile_p::MMAS_M * 2];
if (!(Is_first || mask_val % 2 == 1)) {
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
}
// Commit the data for Q and V to shared memory.
gmem_q.commit(gemm_q_k.smem_q);
gmem_v.commit(smem_v);
// const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
// #pragma unroll
// for(int it=0;it < Gmem_tile_k::LDGS;it++){
// gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
// }
// Commit the data for K to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_k.commit(gemm_q_k.smem_k);
}
__syncthreads();
// Load the fragments for Q.
gemm_q_k.load_q();
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_k.commit(gemm_q_k.smem_k);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for K.
gemm_q_k.load_k();
// Create the object to do the softmax.
Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx);
Smem_softmax_sum smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]), tidx);
// Load over the entire sequence length.
for( int l = 0; l < steps; l++ ) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("block_row_idx = %d\n", block_row_idx);
// }
if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen) break;
int mask_val_next = l < steps - 1 ? blockmask.mask_val(l + 1) : -1;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("mask_val = %d, mask_val_next = %d\n", mask_val, mask_val_next);
// }
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P = Q * K^T.
gemm_q_k(acc_p);
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
bool is_first_read = Is_first || mask_val % 2 == 1;
// if (!Is_first) { gmem_o_tmp.load(out, 0); }
if (!is_first_read) { gmem_o_tmp.load(out, 0); }
// Trigger the load for the next Q values.
bool not_last_iter = (l < steps - 1) && (mask_val_next != -1);
block_row_idx_next = mask_val_next / 4;
int block_row_idx_to_move = block_row_idx_next - block_row_idx;
if (not_last_iter) {
gemm_q_k.smem_q.move_to_next_write_buffer();
gmem_q.move(block_row_idx_to_move);
gmem_q.load();
}
// Load the mask for that iteration.
mask.load(block_row_idx);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack_noscale(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
// softmax.unpack_noscale_half_and_apply_mask(acc_p, mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads();
}
// if (!Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]);
// }
// }
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
// if (!Is_first) {
if (!is_first_read) {
smem_softmax_lse.store_pair(p_prev_lse, l % 2);
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; }
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; }
}
// Trigger the load for the next LSE values.
if (not_last_iter) {
// if (!Is_first) {
if (!(Is_first || mask_val_next % 2 == 1)) {
gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse),
block_row_idx_to_move);
}
}
// __half2 p_max[Mma_tile_p::MMAS_M];
// softmax.template reduce_max</*zero_init=*/Is_first>(p_max);
is_first_read ? softmax.template reduce_max</*zero_init=*/true>(p_max) : softmax.template reduce_max</*zero_init=*/false>(p_max);
// if ((threadIdx.x == 0) && (l == 38)) {
// printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]);
// }
// if (!Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
// }
// }
// Compute the exponential value.
// softmax.apply_exp(p_max);
softmax.scale_apply_exp(p_max, params.scale_bmm1f);
// if (!Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
// }
// }
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
// if (!Is_first) {
// int warp = tidx / Cta_tile_p::THREADS_PER_WARP;
// int lane = tidx % Cta_tile_p::THREADS_PER_WARP;
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
// p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0;
// }
// }
// softmax.reduce_sum(p_sum);
softmax.reduce_sum_before_sync_(p_sum);
// softmax.template reduce_sum_before_sync_</*zero_init=*/Is_first>(p_sum);
// float p_sum_log[Mma_tile_p::MMAS_M * 2];
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) {
// float sum = p_sum[mi];
// // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + __logf(sum);
// constexpr float kLog2e = M_LOG2E;
// p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum);
// }
// // gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum));
// gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum_log));
// gmem_softmax_lse.move();
// // Finalize softmax on the accumulators of P^T.
// softmax.scale(p_sum);
constexpr bool encode_dropout_in_sign_bit = Return_softmax;
if (Is_dropout) {
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, params.p_dropout_in_uint);
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint);
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
}
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
softmax.pack(frag_p);
if (Return_softmax) {
gmem_s.store(frag_p, mask);
if (not_last_iter) {
gmem_s.move(block_row_idx_to_move);
}
}
// Commit the values for Q into shared memory.
if (not_last_iter) {
gmem_q.commit(gemm_q_k.smem_q);
}
if (Is_dropout && encode_dropout_in_sign_bit) {
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
frag_p[ki][mi].hrelu_();
}
}
}
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
}
// The mapping from tidx to rows changes between the softmax and the O-reduction.
// So we recalculate the max.
float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
// TODO: not sure if this is right for seqlen 128 or 256
int rows[Gmem_tile_o::STGS_PER_LOOP];
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
}
softmax.reduce_max_after_sync_(p_max_o, rows);
static_assert(Mma_tile_o::MMAS_M == 1);
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
p_max_o[jj][0] *= params.scale_bmm1f;
}
float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
// if (!Is_first) { smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); }
if (!is_first_read) { smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); }
// if (!Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]);
// }
// }
static_assert(Gmem_tile_o::LOOPS == 1);
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, 0);
// Make sure the data is in shared memory.
__syncthreads();
static_assert(Mma_tile_o::MMAS_M == 1);
float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
softmax.reduce_sum_after_sync_(p_sum_o, rows);
// if (!Is_first) {
if (!is_first_read) {
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
p_sum_o[jj][0] += p_prev_scale_o[jj];
}
}
float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
#pragma unroll
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
float sum = p_sum_o[jj][0];
p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum);
// if (sum == 0.f || sum != sum) {
// printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]);
// }
// if (Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
// }
// }
if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS)) {
gmem_softmax_lse.store_row(
reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
}
}
if (not_last_iter) {
gmem_softmax_lse.move(block_row_idx_to_move);
}
// Load from shared memory.
// if (!Is_first) {
if (!is_first_read) {
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]);
}
}
// smem_o.template load</*zero_init=*/Is_first>(out);
is_first_read ? smem_o.template load</*zero_init=*/true>(out) : smem_o.template load</*zero_init=*/false>(out);
const bool is_final_write =
Is_last
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
|| ((mask_val & 0x2) != 0)
|| ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("is_final_write = %d\n", is_final_write);
// }
#pragma unroll
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
float sum = p_sum_o[jj][0];
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
if (Is_dropout && is_final_write) {
inv_sum *= params.rp_dropout;
}
out[jj] = fmha::fmul4(out[jj], inv_sum);
}
// if (Is_dropout && Is_last) {
// for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
// out[jj] = fmha::fmul4(out[jj], params.rp_dropout);
// }
// }
// Output the values.
if (is_final_write) {
gmem_o.store(out, 0);
} else {
gmem_o_tmp.store(out, 0);
}
// Move to the next part of the output.
gmem_o.move(block_row_idx_to_move);
if (!(Is_first && Is_last)) { gmem_o_tmp.move(block_row_idx_to_move); }
gemm_q_k.reload_k();
// Make sure we are reading from the correct buffer.
gemm_q_k.smem_q.move_to_next_read_buffer();
// Trigger the load from shared memory for the next series of Q values.
if (not_last_iter) {
gemm_q_k.reload_q();
}
if (mask_val_next == -1) break;
mask_val = mask_val_next;
block_row_idx += block_row_idx_to_move;
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, typename Params>
inline __device__ void device_block_1xN_loop(const Params &params) {
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
const int STEPS = params.s / Kernel_traits::Cta_tile_p::M;
constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
if (params.s == N_per_loop) {
fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, ph0, ph1, 0);
} else {
const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, ph0, ph1, 0);
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, ph0, ph1, loop_step_idx);
}
fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, ph0, ph1, max_loop_steps - 1);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha.h>
#include <fmha/utils.h>
#include <fmha/smem_tile.h>
#include <fmha/gmem_tile.h>
#include <fmha/mask.h>
#include <fmha/softmax.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Blockmask {
template<typename Params>
__device__ Blockmask(const Params &params, int loop_step_idx) :
blockmask_ptr(params.blockmask + loop_step_idx * params.s / 16) {
}
__device__ int mask_val(int block_row_idx) const {
return blockmask_ptr[block_row_idx];
}
const int *blockmask_ptr;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
// #include "fmha_dgrad_kernel_1xN_reload.h"
#include "fmha_dgrad_kernel_1xN_reload_recompute.h"
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
// extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_dv_kernel(Fused_multihead_attention_fprop_params params) {
// fmha::compute_dv_1xN<Kernel_traits>(params);
// }
// extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_dq_dk_kernel(Fused_multihead_attention_fprop_params params) {
// fmha::compute_dq_dk_1xN<Kernel_traits>(params);
// }
extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_dp_dq_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dp_dq_1xN<Kernel_traits>(params);
}
extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_dv_dk_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_dk_1xN<Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 512 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
// constexpr int smem_size_dp_dq = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
// constexpr int smem_size_dv_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size_dp_dq = smem_size_q * 2 + smem_size_q + smem_size_v + smem_size_o;
constexpr int smem_size_dv_dk = smem_size_q + smem_size_q + smem_size_v + smem_size_o + smem_size_s;
if( smem_size_dp_dq >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
// fmha_dgrad_fp16_512_64_sm80_dv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dp_dq));
fmha_dgrad_fp16_512_64_sm80_dp_dq_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dp_dq));
}
if( smem_size_dv_dk >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_512_64_sm80_dv_dk_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dv_dk));
}
dim3 grid(params.h, params.b);
// fmha_dgrad_fp16_512_64_sm80_dv_kernel<<<grid, Kernel_traits::THREADS, smem_size_dp_dq, stream>>>(params);
// fmha_dgrad_fp16_512_64_sm80_dp_dq_kernel<<<grid, Kernel_traits::THREADS, smem_size_dp_dq, stream>>>(params);
fmha_dgrad_fp16_512_64_sm80_dp_dq_kernel<<<grid, Kernel_traits::THREADS, smem_size_dp_dq, stream>>>(params);
fmha_dgrad_fp16_512_64_sm80_dv_dk_kernel<<<grid, Kernel_traits::THREADS, smem_size_dv_dk, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment