Unverified Commit 5c9b21d8 authored by yjk21, committed by GitHub

adds fmhalib (#1074)

parent e5f2f675
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "fmha.h"
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
bool is_training,
cudaStream_t stream);
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params,
cudaStream_t stream);
void set_params(Fused_multihead_attention_fprop_params &params,
// sizes
const size_t b,
const size_t s,
const size_t h,
const size_t d,
// device pointers
void *qkv_packed_d,
void *cu_seqlens_d,
void *seqlens_d,
void *o_packed_d,
void *s_d,
float p_dropout) {
Data_type acc_type = DATA_TYPE_FP32;
Data_type data_type = DATA_TYPE_FP16;
// Reset the parameters
memset(&params, 0, sizeof(params));
// Set the pointers and strides.
params.qkv_ptr = qkv_packed_d;
params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);
params.o_ptr = o_packed_d;
params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);
params.cu_seqlens = static_cast<int *>(cu_seqlens_d);
params.seqlens = static_cast<int *>(seqlens_d);
// S = softmax(P)
params.s_ptr = s_d;
params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);
// Set the dimensions.
params.b = b;
params.h = h;
params.s = s;
params.d = d;
// Set the different scale values.
const float scale_bmm1 = 1.f / sqrtf(d);
constexpr float scale_softmax = 1.f;
constexpr float scale_bmm2 = 1.f;
set_alpha(params.scale_bmm1, scale_bmm1, acc_type);
set_alpha(params.scale_softmax, scale_softmax, acc_type);
set_alpha(params.scale_bmm2, scale_bmm2, data_type);
// Set this to the probability of keeping an element, to simplify things.
params.p_dropout = 1.f - p_dropout;
params.rp_dropout = 1.f / params.p_dropout;
TORCH_CHECK(p_dropout < 1.f);
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
}
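// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the original file): the scaling values
// set_params() derives for a hypothetical head_size d = 64 and p_dropout = 0.1.
// scale_bmm1 is 1/sqrt(d); p_dropout is stored as the keep probability and
// rp_dropout is its reciprocal, used to rescale kept activations.
namespace fmha_params_example {
constexpr float kScaleBmm1 = 1.f / 8.f;       // 1/sqrt(64) = 0.125
constexpr float kKeepProb  = 1.f - 0.1f;      // params.p_dropout = 0.9
constexpr float kRpDropout = 1.f / kKeepProb; // params.rp_dropout ~= 1.111
static_assert(kScaleBmm1 == 0.125f, "1/sqrt(64) is exactly representable");
} // namespace fmha_params_example
// ---------------------------------------------------------------------------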
constexpr uint32_t NUM_HEADS_DIM = 2;
constexpr uint32_t THREE_DIM = 1;
std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x 3 x num_heads x head_size, total := \sum_{i=0}^{b-1} s_i
const at::Tensor &cu_seqlens, // b+1
const at::Tensor &seqlens, // b
const float p_dropout,
const int max_seq_len,
const bool is_training,
c10::optional<at::Generator> gen_) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80;
if( max_seq_len <= 128 ) {
seq_len = 128;
launch = &run_fmha_fp16_128_64_sm80;
} else if( max_seq_len <= 256 ) {
seq_len = 256;
launch = &run_fmha_fp16_256_64_sm80;
} else if( max_seq_len <= 384 ) {
seq_len = 384;
launch = &run_fmha_fp16_384_64_sm80;
} else if( max_seq_len <= 512 ) {
seq_len = 512;
launch = &run_fmha_fp16_512_64_sm80;
} else {
TORCH_CHECK(false);
}
constexpr int warps_m = 1;
constexpr int warps_n = 4; // using warps_n = 4 gives an upper bound on elts_per_thread across the kernels above
const int mmas_m = seq_len / 16 / warps_m;
const int mmas_n = seq_len / 16 / warps_n;
const int elts_per_thread = 8 * mmas_m * mmas_n;
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda());
TORCH_CHECK(cu_seqlens.is_cuda());
TORCH_CHECK(qkv.is_contiguous());
TORCH_CHECK(cu_seqlens.is_contiguous());
TORCH_CHECK(seqlens.is_contiguous());
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
TORCH_CHECK(seqlens.numel() == batch_size);
const int total = sizes[0];
const int num_heads = sizes[NUM_HEADS_DIM];
const int head_size = sizes[3];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);
// number of times random numbers will be generated per thread, used to offset the Philox counter
// in the THC random state
int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
launch(params, is_training, stream);
return { ctx, s };
}
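// ---------------------------------------------------------------------------
// Illustrative usage sketch only (not part of the original extension): how a
// host program might call mha_fwd once the module is built and linked. The
// tensor names and sizes below are hypothetical; shapes follow the comments above.
[[maybe_unused]] static std::vector<at::Tensor> mha_fwd_example() {
    const int total = 192, h = 16, d = 64;                   // two sequences: 128 + 64 tokens
    auto opts_h = torch::dtype(torch::kFloat16).device(torch::kCUDA);
    auto opts_i = torch::dtype(torch::kInt32).device(torch::kCUDA);
    auto qkv = torch::randn({total, 3, h, d}, opts_h);       // packed QKV
    auto cu_seqlens = torch::tensor({0, 128, 192}, opts_i);  // b+1 cumulative offsets
    auto seqlens = torch::tensor({128, 64}, opts_i);         // b actual lengths
    // Returns { context (total x h x d), softmax/dropout buffer (b x h x s x s) }.
    return mha_fwd(qkv, cu_seqlens, seqlens, /*p_dropout=*/0.1f,
                   /*max_seq_len=*/128, /*is_training=*/true, c10::nullopt);
}
// ---------------------------------------------------------------------------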
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // total x num_heads x head_size
const at::Tensor &qkv, // total x 3 x num_heads x head_size, total := \sum_{i=0}^{b-1} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const at::Tensor &seqlens, // b
const float p_dropout, // probability to drop
const int max_seq_len // max sequence length to choose the kernel
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
int seq_len = 512;
auto launch = &run_fmha_dgrad_fp16_512_64_sm80;
if( max_seq_len <= 128 ) {
seq_len = 128;
launch = &run_fmha_dgrad_fp16_128_64_sm80;
} else if( max_seq_len <= 256 ) {
seq_len = 256;
launch = &run_fmha_dgrad_fp16_256_64_sm80;
} else if( max_seq_len <= 384 ) {
seq_len = 384;
launch = &run_fmha_dgrad_fp16_384_64_sm80;
} else if( max_seq_len <= 512 ) {
seq_len = 512;
launch = &run_fmha_dgrad_fp16_512_64_sm80;
} else {
TORCH_CHECK(false);
}
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(dout.dtype() == torch::kFloat16);
TORCH_CHECK(softmax.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda());
TORCH_CHECK(cu_seqlens.is_cuda());
TORCH_CHECK(qkv.is_contiguous());
TORCH_CHECK(cu_seqlens.is_contiguous());
TORCH_CHECK(seqlens.is_contiguous());
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
TORCH_CHECK(seqlens.numel() == batch_size);
const int num_heads = sizes[NUM_HEADS_DIM];
const int head_size = sizes[3];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto dqkv = torch::empty_like(qkv);
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
seqlens.data_ptr(),
dout.data_ptr(), // we set o_ptr to dout
softmax.data_ptr(), // softmax gets overwritten by dP!
p_dropout);
// we're re-using these scales: for the backward pass the 1/sqrt(d) factor is applied at the softmax step instead of bmm1
Data_type acc_type = DATA_TYPE_FP32;
set_alpha(params.scale_bmm1, 1.f, acc_type);
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
params.dqkv_ptr = dqkv.data_ptr();
launch(params, stream);
return { dqkv, softmax };
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention for BERT";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <fmha_utils.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
// The QKV matrices.
void *qkv_ptr;
// The stride between rows of the Q, K and V matrices.
size_t qkv_stride_in_bytes;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Fused_multihead_attention_fprop_params : public Qkv_params {
// The dQKV matrices.
void *dqkv_ptr;
// The O matrix (output).
void *o_ptr;
// The stride between rows of O.
int64_t o_stride_in_bytes;
// The pointer to the S matrix, overwritten by the dP matrix (bwd).
void *s_ptr;
// The stride between rows of the S matrix.
int64_t s_stride_in_bytes;
// The dimensions.
int b, h, s, d;
// The scaling factors for the kernel.
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
// array of length b+1 holding starting offset of each sequence.
int *cu_seqlens;
// array of length b holding the actual sequence lengths.
int *seqlens;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
// Scale factor of 1 / (1 - p_dropout), in half2.
uint32_t scale_dropout;
// Random state.
at::PhiloxCudaState philox_args;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha/utils.h>
#define FMHA_DIV_UP(m, n) (((m) + (n)-1) / (n))
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
struct Fragment_base_ {
// The data type.
using Data_type = Data_type_;
// default input type
using Input_type_ = Data_type_;
// Does it store the array of elements.
enum { HAS_ELTS = BITS_PER_ELT_ >= 8 };
// The number of elements.
enum { NUM_ELTS = NUM_ELTS_ };
// The size of an element in bits.
enum { BITS_PER_ELT = BITS_PER_ELT_ };
// The size in bytes of a single register.
enum { BYTES_PER_REG = 4 };
// The size in bits.
enum { BITS_PER_REG = BYTES_PER_REG * 8 };
// The number of registers needed to store the fragment.
enum { NUM_REGS = Div_up<NUM_ELTS * BITS_PER_ELT, BITS_PER_REG>::VALUE };
// The size in bytes (as returned by sizeof(Fragment_base<>)).
enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG };
// The alignment.
enum { ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : Min<NUM_REGS * BYTES_PER_REG, 16>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
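// ---------------------------------------------------------------------------
// Worked example (illustrative, not part of the original header): the values
// Fragment_base_ derives for an 8-element fp16 fragment (16 bits per element),
// re-computed here with plain constexpr arithmetic so the numbers are visible.
namespace fragment_base_example {
constexpr int kNumElts = 8, kBitsPerElt = 16, kBitsPerReg = 32;
constexpr int kNumRegs = (kNumElts * kBitsPerElt + kBitsPerReg - 1) / kBitsPerReg; // Div_up
static_assert(kNumRegs == 4, "8 x fp16 elements fit in 4 x 32-bit registers");
static_assert(kNumRegs * 4 == 16, "SIZE_IN_BYTES == 16, so ALIGNMENT defaults to 16");
} // namespace fragment_base_example
// ---------------------------------------------------------------------------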
template<
// The type of the elements.
typename Data_type_,
// The number of elements.
int NUM_ELTS_,
// The alignment if you want to force a value -- use 0 otherwise.
int ALIGNMENT_ = 0,
// The base class.
typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
// The size of a load/store.
enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) };
// Clear the fragment. Using PTX in that code seems to produce better SASS...
inline __device__ void clear() {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
}
}
// Immutable access to a register.
inline __device__ const uint32_t& reg(int ii) const {
return this->regs_[ii];
}
// Mutable access to a register.
inline __device__ uint32_t& reg(int ii) {
return this->regs_[ii];
}
uint32_t regs_[Base_::NUM_REGS];
// Immutable access to the elements.
inline __device__ const Data_type_& elt(int ii) const {
return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
}
// Mutable access to the elements.
inline __device__ Data_type_& elt(int ii) {
return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
}
// Immutable access to the elements with a cast.
template< typename Cast_type >
inline __device__ const Cast_type& elt_as(int ii) const {
return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
}
// Mutable access to the elements.
template< typename Cast_type >
inline __device__ Cast_type& elt_as(int ii) {
return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
}
// Add another fragment.
inline __device__ void add(const Fragment &other) {
#pragma unroll
for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
this->elt(ii) += other.elt(ii);
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Layout >
struct Fragment_a : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Layout >
struct Fragment_b : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Fragment_accumulator : public Fragment<float, 8> {
// The base class.
using Base = Fragment<float, 8>;
// Add two fragments.
template< typename Other_fragment_ >
inline __device__ void add(const Other_fragment_ &other) {
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
this->elt(ii) = this->elt(ii) + other.elt(ii);
}
}
// Do the HMMA.
template< typename Layout_a, typename Layout_b >
inline __device__ void mma(const Fragment_a<Layout_a> &a,
const Fragment_b<Layout_b> &b) {
asm volatile( \
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
" {%0, %1, %2, %3}, \n" \
" {%4, %5, %6, %7}, \n" \
" {%8, %9}, \n" \
" {%0, %1, %2, %3}; \n" \
: "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3))
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
, "r"(b.reg(0)), "r"(b.reg(1)));
asm volatile( \
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
" {%0, %1, %2, %3}, \n" \
" {%4, %5, %6, %7}, \n" \
" {%8, %9}, \n" \
" {%0, %1, %2, %3}; \n" \
: "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7))
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
, "r"(b.reg(2)), "r"(b.reg(3)));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Fragment, int M, int N >
inline __device__ void clear(Fragment (&frag)[M][N]) {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < N; ++ni ) {
frag[mi][ni].clear();
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Accumulator_type, int WARPS_K >
struct Clear_accumulator {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int WARPS_K >
struct Clear_accumulator<float, WARPS_K> {
template< typename Acc, int M, int N >
static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
fmha::clear(acc);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < N; ++ni ) {
acc[mi][ni].mma(a[mi], b[ni]);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The number of rows in the CTA tile.
int M_,
// The number of cols in the CTA tile.
int N_,
// The number of elements in the K dimension of the GEMM loop.
int K_,
// The number of rows of warps.
int WARPS_M_,
// The number of cols of warps.
int WARPS_N_,
// The number of warps in the K dimension of the GEMM loop.
int WARPS_K_>
struct Cta_tile_ {
enum { M = M_, N = N_, K = K_ };
// The number of warps.
enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ };
// The number of warps per CTA.
enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K };
// The number of threads per warp.
enum { THREADS_PER_WARP = 32 };
// The number of threads per CTA.
enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Hmma_tile {
// The number of elements computed with a single warp-MMA.
enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };
// The number of elements computed with a single CTA-MMA.
enum {
M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K
};
// The number of MMAs needed to compute the GEMM.
enum {
MMAS_M = Div_up<Cta_tile::M, M_PER_MMA_PER_CTA>::VALUE,
MMAS_N = Div_up<Cta_tile::N, N_PER_MMA_PER_CTA>::VALUE,
MMAS_K = Div_up<Cta_tile::K, K_PER_MMA_PER_CTA>::VALUE,
};
// The number of elements computed per warp.
enum {
M_PER_WARP = MMAS_M * M_PER_MMA,
N_PER_WARP = MMAS_N * N_PER_MMA,
K_PER_WARP = MMAS_K * K_PER_MMA,
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
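// ---------------------------------------------------------------------------
// Worked example (illustrative, not part of the original header): the MMA
// counts Hmma_tile derives for a hypothetical Cta_tile with M=16, N=128, K=64
// and a 1x4x1 warp layout (the per-warp MMA shape is 16x16x16).
namespace hmma_tile_example {
constexpr int kM = 16, kN = 128, kK = 64;
constexpr int kWarpsM = 1, kWarpsN = 4, kWarpsK = 1;
constexpr int kThreadsPerCta = kWarpsM * kWarpsN * kWarpsK * 32;  // 128
constexpr int kMmasM = (kM + 16 * kWarpsM - 1) / (16 * kWarpsM);  // 1
constexpr int kMmasN = (kN + 16 * kWarpsN - 1) / (16 * kWarpsN);  // 2
constexpr int kMmasK = (kK + 16 * kWarpsK - 1) / (16 * kWarpsK);  // 4
static_assert(kThreadsPerCta == 128 && kMmasM == 1 && kMmasN == 2 && kMmasK == 4, "");
} // namespace hmma_tile_example
// ---------------------------------------------------------------------------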
using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;
constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile_>
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
Cta_tile_::N,
Next_power_of_two<Cta_tile_::K>::VALUE,
Cta_tile_::WARPS_M,
Cta_tile_::WARPS_N,
Cta_tile_::WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The number of bits per element.
int BITS_PER_ELEMENT,
// The number of rows of Q, K or V loaded by this tile.
int ROWS,
// The number of columns.
int COLS
>
struct Gmem_tile_qkv {
// The size of each LDG.
enum { BYTES_PER_LDG = 16 };
// The size of a row in bytes.
enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 };
// The number of threads to load a "row" of the matrix.
enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG };
// The number of "rows" loaded per LDG.
enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
// The number of LDGs needed to load a chunk of the Q matrix.
enum { LDGS = fmha::Div_up<ROWS, ROWS_PER_LDG>::VALUE };
// Ctor.
template< typename Params, typename BInfo >
inline __device__ Gmem_tile_qkv(const Params &params, int qkv_offset, const BInfo &binfo, int tidx)
: params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
, actual_seqlen(binfo.actual_seqlen)
, qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) {
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % THREADS_PER_ROW;
// Store the row as we need it to disable the loads.
row_ = row;
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
// Add the block index.
row_offset += (int64_t)((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
// Assemble the final pointer.
qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
}
// Store data to shared memory.
template< typename Smem_tile >
inline __device__ void commit(Smem_tile &smem_tile) {
smem_tile.store(fetch_);
}
// Load data from memory.
template< typename Smem_tile >
inline __device__ void load(Smem_tile &smem_tile) {
const void *ptrs[LDGS];
uint32_t preds[LDGS];
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
fetch_[ii] = make_uint4(0, 0, 0, 0);
}
// not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
fct.load(ii, preds[ii]);
}
}
// Store data to memory.
inline __device__ void store(const uint4 (&data)[LDGS]) {
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
fmha::stg(ptr, data[ii]);
}
}
}
// Move the pointer to the next location.
inline __device__ void move() {
qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;
actual_seqlen -= ROWS;
}
// The stride between rows of the QKV matrix.
int64_t params_qkv_stride_in_bytes_;
// The pointer.
char *qkv_ptr_;
// The fetch registers.
uint4 fetch_[LDGS];
// Keep track of the row the thread is processing as we move the tile.
int row_;
// The length of the sequence loaded by that memory tile.
int actual_seqlen;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
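// ---------------------------------------------------------------------------
// Worked example (illustrative, not part of the original header): the LDG
// geometry Gmem_tile_qkv derives for a hypothetical fp16 tile with COLS = 64
// (head size), ROWS = 128 and 128 threads per CTA.
namespace gmem_tile_qkv_example {
constexpr int kBytesPerLdg = 16;
constexpr int kBytesPerRow = 64 * 16 / 8;                      // 128 bytes per row
constexpr int kThreadsPerRow = kBytesPerRow / kBytesPerLdg;    // 8 threads load one row
constexpr int kRowsPerLdg = 128 / kThreadsPerRow;              // 16 rows per LDG wave
constexpr int kLdgs = (128 + kRowsPerLdg - 1) / kRowsPerLdg;   // 8 LDGs cover 128 rows
static_assert(kBytesPerRow == 128 && kThreadsPerRow == 8 && kRowsPerLdg == 16 && kLdgs == 8, "");
} // namespace gmem_tile_qkv_example
// ---------------------------------------------------------------------------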
template< typename Cta_tile >
struct Gmem_tile_o {
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The size of each element.
enum { BYTES_PER_ELEMENT = 2 };
// The size of a row in bytes.
enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT };
// The number of threads to store a "row" of the matrix.
enum { THREADS_PER_ROW = 16 };
// The size of each STG.
enum { BYTES_PER_STG = BYTES_PER_ROW / THREADS_PER_ROW };
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
enum { ROWS = Cta_tile::M };
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };
// The number of outer loops for the stores.
enum { LOOPS = ROWS / ROWS_PER_LOOP };
// The number of "rows" stored per STG.
enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
// Do we have to guard against partial writes/reads.
enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 };
// The number of STGs needed to store a chunk of the Q matrix.
enum { STGS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_STG>::VALUE };
// The number of STGs needed to store a chunk of the Q matrix in total.
enum { STGS = STGS_PER_LOOP * LOOPS };
// Ctor.
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_o(const Params &params, const BInfo &binfo, int tidx)
: params_o_stride_in_bytes_(params.o_stride_in_bytes)
, actual_seqlen_(binfo.actual_seqlen)
, o_ptr_(reinterpret_cast<char *>(params.o_ptr)) {
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % THREADS_PER_ROW;
// Store the row as we need it to disable loads.
row_ = row;
// The row offset in the batched GEMM.
int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
// Assemble the final pointer.
o_ptr_ += row_offset + col * BYTES_PER_STG;
// Is that thread active on the last STG?
if( HAS_INCOMPLETE_STG ) {
is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
}
}
// Store data to global memory.
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) {
break;
}
float x = reinterpret_cast<const float &>(src[ii].x);
float y = reinterpret_cast<const float &>(src[ii].y);
float z = reinterpret_cast<const float &>(src[ii].z);
float w = reinterpret_cast<const float &>(src[ii].w);
uint2 out = float4_to_half4(x, y, z, w);
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out);
}
}
}
// Move the pointer to the next location.
inline __device__ void move() {
row_ += ROWS;
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;
}
// The stride between rows of the O matrix.
int64_t params_o_stride_in_bytes_;
// The pointer.
char *o_ptr_;
// Is the thread active for the last STG?
int is_active_for_last_stg_;
// Keep track of the row to disable loads.
int row_;
// The length of the sequence loaded by that memory tile.
int actual_seqlen_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
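// ---------------------------------------------------------------------------
// Worked example (illustrative, not part of the original header): the STG
// geometry Gmem_tile_o derives for a hypothetical tile with Cta_tile::N = 64,
// Cta_tile::M = 16 and 128 threads per CTA.
namespace gmem_tile_o_example {
constexpr int kBytesPerRow = 64 * 2;                      // fp16 output row: 128 bytes
constexpr int kBytesPerStg = kBytesPerRow / 16;           // 16 threads per row -> 8B (uint2) STGs
constexpr int kRowsPerStg = 128 / 16;                     // 8 rows written per STG wave
constexpr int kStgsPerLoop = (16 + kRowsPerStg - 1) / kRowsPerStg; // 2 STGs for 16 rows
static_assert(kBytesPerStg == 8 && kRowsPerStg == 8 && kStgsPerLoop == 2, "");
} // namespace gmem_tile_o_example
// ---------------------------------------------------------------------------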
template< typename Cta_tile, int BYTES_PER_ELEMENT >
struct Gmem_tile_mma_sd {
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// Each STG stores 8 elements.
enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 };
// The number of MMAs in the M dimension.
enum { MMAS_M = Mma_tile::MMAS_M };
// The number of MMAs in the N dimension.
enum { MMAS_N = Mma_tile::MMAS_N };
// The number of rows computed per MMA per thread block.
enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA };
// The number of cols computed per MMA per thread block.
enum { N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA };
// The number of threads per block.
enum { THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA };
// The size of each row in bytes. I.e. how many bytes are stored per STG.
enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG };
// The fixed sequence length.
enum { SEQLEN = Cta_tile::N };
// The distance between two blocks (in bytes).
enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT };
// The distance between elements stored per loop (in bytes).
enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW };
// The type of elements stored per STG.
using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;
// Ctor.
template<typename Params>
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int tidx)
: ptr_(static_cast<char *>(ptr)) {
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The block index.
size_t bidx = bidb * params.h + bidh;
// Set store location for each thread at the beginning of the loop
ptr_ += bidx * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG;
}
// Store to global memory.
inline __device__ void store(const Type &data, const int mi, const int ni) {
size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::stg(ptr_ + offset, data);
}
// Load from global memory.
inline __device__ void load(Type &data, const int mi, const int ni) {
size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::ldg(data, ptr_ + offset);
}
// Move to the next tile.
inline __device__ void move() {
ptr_ += LOOP_STRIDE_BYTES;
}
// The pointer in global memory.
char *ptr_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
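// ---------------------------------------------------------------------------
// Worked example (illustrative, not part of the original header): the S/D
// store geometry Gmem_tile_mma_sd derives for a hypothetical fp16 tile with a
// fixed sequence length of 128 and 128 threads per CTA.
namespace gmem_tile_mma_sd_example {
constexpr int kBytesPerElement = 2, kSeqlen = 128, kThreadsPerCta = 128;
constexpr int kBytesPerStg = kBytesPerElement * 8;                      // 16B: 8 fp16 values per STG
constexpr int kBytesPerRow = kThreadsPerCta * kBytesPerStg;             // 2048B written per "row"
constexpr int kBlockStrideBytes = kSeqlen * kSeqlen * kBytesPerElement; // 32 KB per (batch, head) slice
static_assert(kBytesPerStg == 16 && kBytesPerRow == 2048 && kBlockStrideBytes == 32768, "");
} // namespace gmem_tile_mma_sd_example
// ---------------------------------------------------------------------------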
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
struct Gmem_tile_mma_s : public Base {
// The number of mmas in the vertical dimension.
enum { M = Base::MMAS_M };
// The number of mmas in the horizontal dimension.
enum { N = Base::MMAS_N };
// The type of the vectors stored by each STG.
using Type = typename Base::Type;
// Ctor.
template< typename Params >
inline __device__ Gmem_tile_mma_s(void *ptr, const Params &params, const int tidx)
: Base(ptr, params, tidx) {
}
// Store to global memory.
template<typename Mask>
inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
float tmp00 = softmax[2 * mi + 0][4 * ni + 0];
float tmp01 = softmax[2 * mi + 0][4 * ni + 1];
float tmp02 = softmax[2 * mi + 0][4 * ni + 2];
float tmp03 = softmax[2 * mi + 0][4 * ni + 3];
float tmp10 = softmax[2 * mi + 1][4 * ni + 0];
float tmp11 = softmax[2 * mi + 1][4 * ni + 1];
float tmp12 = softmax[2 * mi + 1][4 * ni + 2];
float tmp13 = softmax[2 * mi + 1][4 * ni + 3];
uint4 dst;
dst.x = fmha::float2_to_half2(tmp00, tmp01);
dst.y = fmha::float2_to_half2(tmp02, tmp03);
dst.z = fmha::float2_to_half2(tmp10, tmp11);
dst.w = fmha::float2_to_half2(tmp12, tmp13);
if( mask.is_valid(mi, ni, 0, 0) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Load from global memory.
template<typename Mask>
inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
regs[mi][ni] = make_uint4(0, 0, 0, 0);
if( mask.is_valid(mi, ni, 0, 0) ) {
Base::load(regs[mi][ni], mi, ni);
}
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The base class.
typename Base = fmha::Gmem_tile_qkv<Cta_tile, fmha::BITS_PER_ELEMENT_A, Cta_tile::M, Cta_tile::K>
>
struct Gmem_tile_dout : public Base {
// Ctor.
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_dout(const Params &params, const BInfo &binfo, int tidx)
: Base(params, 0, binfo, tidx) {
this->qkv_ptr_ = reinterpret_cast<char *>(params.o_ptr);
this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes; // needed for move
// Compute the position of the thread in the row.
int col = tidx % Base::THREADS_PER_ROW;
// The row offset in the batched GEMM; dO shares the layout of O.
int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
// Assemble the final pointer.
this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, typename Base = fmha::Gmem_tile_o<Cta_tile> >
struct Gmem_tile_dq : public Base {
// Ctor.
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_dq(const Params &params, const BInfo &binfo, int tidx)
: Base(params, binfo, tidx) {
this->o_ptr_ = reinterpret_cast<char *>(params.dqkv_ptr);
this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes; // needed for move
// Compute the position of the thread in the row.
int col = tidx % Base::THREADS_PER_ROW;
// The row offset in the batched GEMM. dQ is written at the Q slot of the packed dQKV tensor.
int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes +
(binfo.sum_s * 3 * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
// Assemble the final pointer.
this->o_ptr_ += row_offset + col * Base::BYTES_PER_STG;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x8u>
struct FMHA_kernel_traits {
// The CTA description for the 1st GEMM.
using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;
// The CTA description for the 2nd GEMM.
using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;
// Do we use one buffer for K and V.
enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u };
// The global memory tile to load Q.
using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
// The shared memory tile to swizzle Q.
using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
// The global memory tile to load K.
using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;
// The shared memory tile to swizzle K.
using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;
// The global memory tile to load V.
using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;
// The shared memory tile to swizzle V.
using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;
// The global memory tile to store O.
using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;
// The shared memory tile for O.
using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;
// The global memory tile to load/store S.
using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;
// The shared memory tile to transpose S.
using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;
using Gmem_tile_do = fmha::Gmem_tile_dout<Cta_tile_p>;
// Make sure the number of threads match.
static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");
// The number of threads.
enum { THREADS = Cta_tile_p::THREADS_PER_CTA };
// Make sure the number of threads matches both CTAs.
static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, "");
// The amount of shared memory needed to load Q and K.
enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE };
// The extra amount of shared memory needed to load V.
enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE };
// The amount of shared memory needed for Q, K and V.
enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V };
// The amount of shared memory needed to load Q and store O.
enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE };
// The amount of shared memory needed for Q, K, V and O.
enum { BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE };
// Make sure we have enough shared memory.
static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {
template<typename Cta_tile>
struct Mask {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
template<typename Params, typename BInfo>
__device__ Mask(const Params &params, const BInfo &blockInfo, int tidx) {
actual_seqlen = blockInfo.actual_seqlen;
const int warp = tidx / Cta_tile::THREADS_PER_WARP;
const int lane = tidx % Cta_tile::THREADS_PER_WARP;
static_assert(Cta_tile::WARPS_K == 1, "");
// find the warp in the Cta tile
const int warp_n = (warp / Cta_tile::WARPS_M);
const int warp_m = (warp % Cta_tile::WARPS_M);
// decompose warp into 8x4 tile
const int quad = lane / 4;
const int tid = (lane % 4) * 2;
row = warp_m * 16 + quad;
col = warp_n * 16 + tid;
}
inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
// ii and jj iterate over the 2x4 fragment
const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen;
//&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;
return col_valid;
// return row_valid && col_valid;
}
inline __device__ void load(int it) {
row_offset = it * Cta_tile::M + row;
}
int row_offset;
int row;
int col;
int actual_seqlen;
};
} // namespace fmha
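// ---------------------------------------------------------------------------
// Worked example (illustrative, not part of the original header): how Mask
// decomposes a thread index for a hypothetical tile with WARPS_M = 1 and
// 32 threads per warp. Thread 37 is lane 5 of warp 1, so it sits in quad 1 of
// the 8x4 lane layout and its first column offset within the warp tile is 2.
namespace mask_example {
constexpr int kTidx = 37, kThreadsPerWarp = 32, kWarpsM = 1;
constexpr int kWarp = kTidx / kThreadsPerWarp;   // 1
constexpr int kLane = kTidx % kThreadsPerWarp;   // 5
constexpr int kWarpN = kWarp / kWarpsM;          // 1
constexpr int kWarpM = kWarp % kWarpsM;          // 0
constexpr int kQuad = kLane / 4;                 // 1
constexpr int kTid  = (kLane % 4) * 2;           // 2
static_assert(kWarpM * 16 + kQuad == 1, "row within the CTA tile");
static_assert(kWarpN * 16 + kTid == 18, "col within the CTA tile");
} // namespace mask_example
// ---------------------------------------------------------------------------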
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha/utils.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The description of the tile computed by this CTA.
typename Cta_tile,
// The number of rows in the 2D shared memory buffer.
int M_,
// The number of cols.
int N_,
// The size in bits of each element.
int BITS_PER_ELEMENT_,
// The number of bytes per STS.
int BYTES_PER_STS_ = 16,
// The number of buffers. (Used in multistage and double buffer cases.)
int BUFFERS_PER_TILE_ = 1,
// Do we enable the fast path for LDS.128 and friends.
int ENABLE_LDS_FAST_PATH_ = 0,
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
int ROWS_PER_XOR_PATTERN_ = 8,
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
int COLS_PER_XOR_PATTERN_ = 1,
// Use or not predicates
bool USE_PREDICATES_ = true
>
struct Smem_tile_without_skews {
// The size in bits of each element.
enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };
// The size in bytes of a single STS.
enum { BYTES_PER_STS = BYTES_PER_STS_ };
// The number of elements per STS.
enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };
// To support arbitrary N, we pad some values to a power-of-2.
enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE };
// The number of bytes per row without packing of rows.
enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
// The number of bytes per row -- we want at least 128B per row.
enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
// The number of rows in shared memory (two rows may be packed into a single one).
enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };
// The number of threads per row.
enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS };
// The number of threads per row.
enum { THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE };
// The number of STS per row.
enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS };
// It must be at least one.
static_assert(STS_PER_ROW >= 1, "");
// The number of rows written with a single STS.
enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
// Make sure we write to at least one row per STS. Thanks Dr. Obvious ;)
static_assert(ROWS_PER_STS >= 1, "");
// The number of STS needed to store all rows.
enum { STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE };
// The number of STS in total.
enum { STS = STS_PER_COL * STS_PER_ROW };
// The size of one buffer in bytes in shared memory.
enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };
// The number of buffers.
enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };
// The size in bytes of total buffers.
enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE };
// The boundary for smem_read_offset and smem_write_offset increment.
enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER };
// Do we enable the LDS.128 fast path?
enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ };
static_assert(ENABLE_LDS_FAST_PATH == 0);
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ };
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS };
// Use or not predicates
enum { USE_PREDICATES = USE_PREDICATES_ };
// The type of elements that are stored in shared memory by each thread.
using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;
// Ctor.
inline __device__ Smem_tile_without_skews(void *smem, int tidx)
: smem_(__nvvm_get_smem_pointer(smem)) {
// The row written by a thread. See doc/mma_smem_layout.xlsx.
int smem_write_row = tidx / THREADS_PER_ROW;
// The XOR pattern.
int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;
// Compute the column and apply the XOR pattern.
int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;
// The offset.
this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
// TODO: Why not merge it with the read offset?
this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
}
// Compute the store pointers.
template< int N >
inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
// Decompose the STS into row/col.
int row = ii / STS_PER_ROW;
int col = ii % STS_PER_ROW;
// Assemble the offset.
int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW;
// Take the column into account.
if( STS_PER_ROW > 1 ) {
offset += col*THREADS_PER_ROW*BYTES_PER_STS;
}
// Apply the XOR pattern if needed.
if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {
const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
}
// Assemble the final pointer :)
ptrs[ii] = smem_ + offset + smem_write_buffer_;
}
}
inline __device__ void debug_reset() {
for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
for( int row = 0; row < ROWS; ++row ) {
for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
if( threadIdx.x == 0 ) {
uint32_t val = 0x0;
sts(val, smem_ + row*BYTES_PER_ROW + col + buffer);
}
}
}
}
}
// Print the content of the tile (only for debug ;)).
inline __device__ void debug_print() const {
for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
for( int row = 0; row < ROWS; ++row ) {
for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
if( threadIdx.x == 0 ) {
uint32_t val;
lds(val, smem_ + row*BYTES_PER_ROW + col + buffer);
printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n",
blockIdx.x,
blockIdx.y,
blockIdx.z,
smem_,
buffer,
row,
col,
val);
}
}
}
}
}
// Move the read offset to next buffer.
inline __device__ void move_to_next_read_buffer() {
if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_read_buffer_ += BYTES_PER_BUFFER;
}
}
// Move the read offset to next buffer. TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer() {
this->move_to_next_read_buffer();
}
// Move the read offset to next N buffer (circular-buffer).
inline __device__ void move_to_next_read_buffer(int N) {
if( BUFFERS_PER_TILE > 1 ) {
this->smem_read_buffer_ += N * BYTES_PER_BUFFER;
this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
}
}
// Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer(int N) {
this->move_to_next_read_buffer(N);
}
// Move the write offset to next buffer.
inline __device__ void move_to_next_write_buffer() {
if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_write_buffer_ += BYTES_PER_BUFFER;
}
}
// Move the write offset to next buffer. TODO: Remove that member function!
inline __device__ void move_next_write_buffer() {
this->move_to_next_write_buffer();
}
// Move the read offset.
inline __device__ void move_read_offset(int delta) {
this->smem_read_offset_ += delta;
}
// Move the write offset.
inline __device__ void move_write_offset(int delta) {
this->smem_write_offset_ += delta;
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
sts(smem_ptrs, data);
}
// Store to the tile in shared memory.
template< int N, int M >
inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) {
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
sts(smem_ptrs, data, preds);
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) {
this->store(data, preds);
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) {
uint32_t tmp[1] = { preds };
this->store(gmem_ptrs, tmp);
}
// The shared memory pointer.
uint32_t smem_;
// The read offset. Reserve 4 offsets if needed.
int smem_read_offset_;
// The write offset.
int smem_write_offset_;
// The buffer base offset for read.
int smem_read_buffer_;
// The buffer base offset for write.
int smem_write_buffer_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
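// ---------------------------------------------------------------------------
// Worked example (illustrative, not part of the original header): the sizes
// Smem_tile_without_skews derives for a hypothetical fp16 Q tile with M = 16,
// N = 64, 16B STS, a single buffer and 128 threads per CTA.
namespace smem_tile_example {
constexpr int kThreads = 128, kBytesPerSts = 16;
constexpr int kBytesPerRowBeforePacking = 64 * 16 / 8;                   // 128B
constexpr int kBytesPerRow = kBytesPerRowBeforePacking < 128 ? 128 : kBytesPerRowBeforePacking;
constexpr int kRows = 16 * kBytesPerRowBeforePacking / kBytesPerRow;     // 16 smem rows
constexpr int kThreadsPerRow = kBytesPerRow / kBytesPerSts;              // 8
constexpr int kStsPerRow = kBytesPerRow / kThreadsPerRow / kBytesPerSts; // 1
constexpr int kRowsPerSts = kThreads / kThreadsPerRow;                   // 16
constexpr int kSts = ((kRows + kRowsPerSts - 1) / kRowsPerSts) * kStsPerRow; // 1 STS per thread
constexpr int kBytesPerBuffer = kSts * kBytesPerSts * kThreads;          // 2 KB per buffer
static_assert(kRows == 16 && kThreadsPerRow == 8 && kSts == 1 && kBytesPerBuffer == 2048, "");
} // namespace smem_tile_example
// ---------------------------------------------------------------------------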
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
// Use or not predicates
bool USE_PREDICATES = true
>
struct Smem_tile_a {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K, int MMAS_K_WITH_PADDING >
struct Compute_reset_mask {
// The potential mask.
enum { HALF = MMAS_K_WITH_PADDING / 2 };
// The remainder.
enum { MOD = MMAS_K % HALF };
// The final value.
enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask<MOD, HALF>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K_WITH_PADDING >
struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> {
enum { VALUE = 0 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K >
struct Compute_reset_mask<MMAS_K, MMAS_K> {
enum { VALUE = MMAS_K - 1 };
};
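// For illustration (only relevant when K is padded): with MMAS_K = 3 padded to
// MMAS_K_WITH_PADDING = 4, the recursion gives HALF = 2, MOD = 1 and
// VALUE = 2 | Compute_reset_mask<1, 2>::VALUE = 2 | 1 = 3, so XOR-ing the read offset with
// 3 * BYTES_PER_LDS * 2 after the 3rd load brings it back to its starting value.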
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_a {
// The size in bits.
enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };
// The number of rows.
enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};
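// For example, assuming 16-bit A elements: N = 64 gives N_IN_BITS = 1024 > 512, so
// VALUE = 8 rows per XOR pattern, which is what the tiles below assert.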
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a<Cta_tile::K>::VALUE
>
struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
Cta_tile::M,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_A,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::M,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_A,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1>;
// The fragment.
using Fragment = Fragment_a<Row>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// Ctor.
inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) {
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(WARPS_M == 1);
static_assert(WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_K == 1);
static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
// The row and column read by the thread.
int smem_read_row = (tidx & 0x0f);
int smem_read_col = (tidx & 0x07);
smem_read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {
#pragma unroll
for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) {
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
// Store the value into the fragment.
a[mi].reg(0) = tmp.x;
a[mi].reg(1) = tmp.y;
a[mi].reg(2) = tmp.z;
a[mi].reg(3) = tmp.w;
}
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
}
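// For illustration, with MMAS_K_WITH_PADDING = 4 the chain above XORs the read offset with
// 32, 96, 32, 96 bytes for ki = 0, 1, 2, 3, so successive loads read at relative offsets
// +0, +32, +64 and +96 before wrapping back to +0.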
// Reset the read offset.
inline __device__ void reset_read_offset() {
// The number of MMAs in the K dimension.
enum { MMAS_K = Mma_tile::MMAS_K };
// The number of MMAs in the K dimension when we include padding.
enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
// Assemble the mask.
enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_row_a<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
// Whether to use predicates.
bool USE_PREDICATES = true
>
struct Smem_tile_b {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_b {
// The size in bits.
enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };
// The number of rows.
enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b<Cta_tile::K>::VALUE
>
struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
Cta_tile::N,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::N,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1>;
// The fragment.
using Fragment = Fragment_b< Col>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// The number of STS per thread
enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
// The number of STS per thread must be at least 1.
enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };
// Ctor.
inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) {
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
static_assert(WARPS_M == 1);
static_assert(WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_K == 1);
// The masks to select the warps.
const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
// The divisor for the warps.
const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
// The row and column read by the thread.
int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA +
(tidx & 0x07) +
(tidx & 0x10) / 2;
int smem_read_col = (tidx & 0x07);
smem_read_col ^= (tidx & 0x08) / 8;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
}
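// For example, with WARPS_N = 4: Warp_masks<1, 4, 1>::N == 0x60 and WARP_DIV_N == 32, so
// (tidx & 0x60) / 32 recovers the warp index 0..3 along N and each warp starts
// Mma_tile::N_PER_MMA rows further down; (tidx & 0x07) and (tidx & 0x10) / 2 then pick the
// row within the warp's 16-row slice.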
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
// Store the value into the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
}
// Move the offset to the next position. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset() {
// The number of MMAs in the K dimension.
enum { MMAS_K = Mma_tile::MMAS_K };
// The number of MMAs in the K dimension when we include padding.
enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
// Assemble the mask.
enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_b< Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE >
: public Smem_tile_col_b<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_col_b< Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b< N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b<Cta_tile::N>::VALUE,
// How many cols to use for the XOR pattern to avoid bank conflicts?
int COLS_PER_XOR_PATTERN_ = 1
>
struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
Cta_tile::K,
Cta_tile::N,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
COLS_PER_XOR_PATTERN_> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::K,
Cta_tile::N,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
COLS_PER_XOR_PATTERN_>;
// The fragment.
using Fragment = Fragment_b<Row>;
// Can we use LDSM? Not if the elements are 32 bits wide.
enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };
// The number of elements per LDS.
enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };
// The number of STS per thread
enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
// The number of STS per thread must be at least 1.
enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };
// Ctor.
inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) {
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(WARPS_K == 1);
static_assert(WARPS_M == 4 || WARPS_M == 8);
static_assert(WARPS_N == 1);
// The masks to select the warps.
const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
const int WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
// The row/col read by the thread.
int smem_read_row, smem_read_col;
static_assert(USE_LDSMT);
static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 +
(tidx & 0x07) + (tidx & 0x08);
smem_read_col = (tidx & 0x07);
smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// The size of each element in bits.
const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( BYTES_PER_MMA_PER_CTA >= 128 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
} else if( BYTES_PER_MMA_PER_CTA == 64 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
Mma_tile::MMAS_N % 2 == 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
// The size of each element in bits.
const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Prepare the offset.
int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW;
if ( BYTES_PER_MMA_PER_CTA == 32 ) {
offset += this->smem_read_offset_;
} else if ( BYTES_PER_MMA_PER_CTA == 64 ) {
offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2;
} else {
offset += this->smem_read_offset_ + (ni ) * BYTES_PER_MMA_PER_CTA;
}
// Load the data using LDSM.MT88.2.
uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;
uint4 tmp;
if( USE_LDSMT ) {
ldsmt(tmp, ptr);
} else {
lds(tmp.x, (ptr ) + 0*Base::BYTES_PER_ROW);
lds(tmp.y, (ptr ) + 4*Base::BYTES_PER_ROW);
lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW);
lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW);
}
// Store those values in the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if( BYTES_PER_MMA_PER_CTA >= 128 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
} else if( BYTES_PER_MMA_PER_CTA == 64 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
Mma_tile::MMAS_N % 2 == 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_row_b<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1> {
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1>;
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The fragment.
using Fragment = Fragment_b< fmha::Col>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// Ctor.
inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) {
// The row/col read by the thread.
int read_row, read_col;
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);
read_col = (tidx & 0x07);
read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Jump by 16 * #warps rows.
int row = ki * 16 * Cta_tile::WARPS_K;
// Load the data using LDSM.MT88.2.
uint4 tmp;
fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW);
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if( Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else {
assert(false); // Not implemented!
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Smem_tile_o {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The accumulators.
using Accumulator = fmha::Fragment_accumulator;
// The data type of the accumulators.
using Data_type = typename Accumulator::Data_type;
// The size of each element.
enum { BYTES_PER_ELEMENT = sizeof(Data_type) };
// The size of each STS.
enum { BYTES_PER_STS = 8 };
// The size of each row in shared memory.
enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT };
// The size of each LDS.
enum { BYTES_PER_LDS = 16 };
enum { THREADS_PER_ROW = 16 };
// The number of rows.
enum { ROWS = Cta_tile::M };
// The number of "rows" to process per loop iteration (in the "epilogue").
enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };
// The number of outer loops.
enum { LOOPS = ROWS / ROWS_PER_LOOP };
// Make sure it matches our expectations.
static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");
// The number of rows loaded per LDS.
enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
// Do we have to guard against partial writes/reads.
enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 };
// The total number of LDS per loop.
enum { LDS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_LDS>::VALUE };
// The amount of shared memory.
enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW };
// The write pointer.
uint32_t smem_write_, smem_read_;
// Is the thread active for the last LDS of the series?
int is_active_for_last_lds_;
static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K);
static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");
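// For instance, with Cta_tile::N = 64, WARPS_K = 4 and 4-byte (fp32) accumulators,
// BYTES_PER_ROW = 64 * 4 * 4 = 1024: each row stores the WARPS_K split-K partial results
// side by side, and load() below sums them with fadd4.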
// Ctor.
inline __device__ Smem_tile_o(void *smem, int tidx) {
// Get a 32-bit value for the shared memory address.
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
int write_row = (tidx & 0x1c) / 4;
int write_col = (tidx);
// Assemble the write pointer.
smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
// The element read by each thread.
int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
// Take the XOR pattern into account for the column.
read_col ^= 2 * (read_row & 0x7);
// Assemble the read pointer.
this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
// Is that thread active on the last LDS?
if( HAS_INCOMPLETE_LDS ) {
this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M;
}
}
// Load the output fragments.
inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
#pragma unroll
for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {
// Load the elements before the reduction (split-K).
uint4 tmp[Cta_tile::WARPS_K];
#pragma unroll
for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
fmha::lds(tmp[jj], this->smem_read_ + imm);
}
}
// Perform the reduction.
out[ii] = tmp[0];
#pragma unroll
for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
out[ii] = fmha::fadd4(out[ii], tmp[jj]);
}
}
}
// Store the accumulators.
template <int M, int N>
inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// The number of MMAs that are stored per loop iteration.
enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS };
// Store 1st column of the different MMAs.
#pragma unroll
for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
// Precompute the immediates to jump between rows.
int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;
uint2 tmp0, tmp1;
tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0);
tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1);
tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2);
tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3);
// Store.
fmha::sts(this->smem_write_ + row_0, tmp0);
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// Swizzle the write pointer using an XOR of 32B.
this->smem_write_ ^= 32;
// Store 2nd column of the different MMAs.
#pragma unroll
for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
// Precompute the immediates to jump between rows.
int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;
uint2 tmp0, tmp1;
tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4);
tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5);
tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6);
tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7);
// Store.
fmha::sts(this->smem_write_ + row_0, tmp0);
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// Cancel the previous 32B XOR and advance the write pointer (a net XOR of 64B or 192B).
this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_mma {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
using Fragment = fmha::Fragment_a<fmha::Col>;
enum { COLS = Cta_tile::N };
enum { BYTES_PER_ELT = 2 };
enum { BYTES_PER_STS = 4 };
enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO
enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
enum { WARPS_K = Cta_tile::WARPS_K };
static_assert(WARPS_K == 1);
inline __device__ Smem_tile_mma(char *smem, int tidx) {
smem_ = __nvvm_get_smem_pointer(smem);
int write_col, write_row;
static_assert((WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)) || ((WARPS_M == 4 || WARPS_M == 8) && WARPS_N == 1));
if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
write_row = (tidx & 0x1c) / 4;
write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
} else {
write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
write_col = (tidx & 0x03);
}
write_col ^= (write_row & 0x07) * 4;
write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
static_assert(COLS == Cta_tile::N);
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
}
}
}
uint32_t smem_;
uint32_t write_offset_;
uint32_t warp_m;
uint32_t warp_n;
uint32_t lane;
};
template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_transposed : public Base {
enum { BYTES_PER_LDS = 16 };
enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
using Fragment = typename Base::Fragment;
inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
int read_row, read_col;
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
read_col ^= (read_row & 0x07);
read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
template<int M, int N>
inline __device__ void load(Fragment (&frag)[M][N]) {
static_assert(Base::COLS == Cta_tile::N);
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
frag[mi][ni].reg(0) = dst.x;
frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major!
frag[mi][ni].reg(2) = dst.y;
frag[mi][ni].reg(3) = dst.w;
}
}
}
uint32_t read_offset_;
};
template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_epilogue : public Base {
enum { BYTES_PER_LDS = 16 };
enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS };
static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW);
enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS };
static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert((WARPS_M == 4 || WARPS_M == 8) && WARPS_N == 1);
using Fragment = typename Base::Fragment;
inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
const int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
read_col ^= (read_row & 0x07);
read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
inline __device__ void load(uint4 (&data)[NUM_LDS]) {
for( int ii = 0; ii < NUM_LDS; ii++ ) {
size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
fmha::lds(data[ii], this->smem_ + offset);
}
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * Base::BYTES_PER_STS;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
}
}
}
uint32_t read_offset_;
};
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Sum_ {
enum { IS_SUM = 1 };
static inline __device__ float apply(float x, float y) {
return x + y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Max_ {
enum { IS_SUM = 0 };
static inline __device__ float apply(float x, float y) {
return x > y ? x : y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float apply_exp_(float x, float max) {
return __expf(x - max);
}
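// Subtracting the row max before exponentiating keeps __expf in a safe range (every result
// is <= 1), which is the usual numerically stable softmax formulation.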
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile, typename Kernel_traits>
struct Softmax_base {
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
enum { MMAS_M = Mma_tile::MMAS_M };
enum { MMAS_N = Mma_tile::MMAS_N };
// The number of groups of warps such that we have at most 4 warps writing consecutive elements.
enum { GROUPS = fmha::Div_up<Cta_tile::WARPS_N, 4>::VALUE };
// The number of elements that we are going to store per row.
enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS };
// The number of rows.
enum { ROWS = Cta_tile::M * GROUPS };
// The total number of elements.
enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW };
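// For example, WARPS_N = 8 gives GROUPS = 2, ELEMENTS_PER_ROW = 4 and ROWS = 2 * Cta_tile::M,
// while WARPS_N = 4 gives GROUPS = 1, ELEMENTS_PER_ROW = 4 and ROWS = Cta_tile::M.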
// Ctor.
template<typename Params>
inline __device__ Softmax_base(const Params &params, void *smem, int bidb, int tidx)
: // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {
// Move to the 1st mask loaded by the thread.
// packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);
// Extract the position in the warp.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// Decompose the warp index into M and N.
int warp_m = warp % Cta_tile::WARPS_M;
int warp_n = warp / Cta_tile::WARPS_M;
// Decompose the warp-n index into group/position-inside-the-group.
int warp_g = warp_n / ELEMENTS_PER_ROW;
int warp_i = warp_n % ELEMENTS_PER_ROW;
// The location written by the threads.
int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
int write_col = warp_i;
// Assemble the write pointer.
smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];
// Assemble the read pointer.
smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
}
template<typename Mask>
inline __device__ void apply_mask(const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ii = 0; ii < 2; ++ii ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
#pragma unroll
for( int jj = 0; jj < 4; ++jj ) {
if( !mask.is_valid(mi, ni, ii, jj) ) {
elt_[2 * mi + ii][4 * ni + jj] = -INFINITY;
}
}
}
}
}
}
// Apply the exp to all the elements.
inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
}
}
}
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce_1x4(float (&dst)[MMAS_M * 2]) {
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
if( Functor::IS_SUM ) {
// Apply the summation inside the thread.
float tmp[MMAS_M * 2][2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
tmp[mi][0] = 0.f;
tmp[mi][1] = 0.f;
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
tmp[mi][0] += elt_[mi][4 * ni + 0];
tmp[mi][0] += elt_[mi][4 * ni + 1];
tmp[mi][1] += elt_[mi][4 * ni + 2];
tmp[mi][1] += elt_[mi][4 * ni + 3];
}
dst[mi] = tmp[mi][0] + tmp[mi][1];
}
} else
#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
{
// Apply the functor for each row inside a thread.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = elt_[mi][0];
#pragma unroll
for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
}
}
}
// Apply the functor for each row inside each group of 4 threads.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
__syncwarp();
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
__syncwarp();
}
// Store the different values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
if( tidx_ % 4 == 0 ) {
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
}
}
// Make sure the values are in shared memory.
__syncthreads();
// Load 4 values (one per warp) for this thread's row: a single float4, since
// ELEMENTS_PER_ROW == 4 when WARPS_N == 4.
float4 tmp[1];
if( tidx_ < Cta_tile::M ) {
tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
}
// Compute the reduction of those 8 values in a binary-tree fashion.
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
// Make sure we can write to shared memory.
__syncthreads();
// Store the value back to shared memory.
if( tidx_ < Cta_tile::M ) {
smem_[tidx_] = tmp[0].x;
}
// Make sure the data is in shared memory.
__syncthreads();
// Finally read the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
}
}
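// In short: each thread reduces its own elements, the __shfl_xor_sync pair combines the
// 4 threads that share a row of the MMA tile, the per-warp results go through shared memory
// where one float4 read gathers the 4 warp values of a row, and the final value is broadcast
// back through shared memory.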
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce_1x8(float (&dst)[MMAS_M * 2]) {
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
if( Functor::IS_SUM ) {
// Apply the summation inside the thread.
float tmp[MMAS_M * 2][2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
tmp[mi][0] = 0.f;
tmp[mi][1] = 0.f;
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
tmp[mi][0] += elt_[mi][4 * ni + 0];
tmp[mi][0] += elt_[mi][4 * ni + 1];
tmp[mi][1] += elt_[mi][4 * ni + 2];
tmp[mi][1] += elt_[mi][4 * ni + 3];
}
dst[mi] = tmp[mi][0] + tmp[mi][1];
}
} else
#endif // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
{
// Apply the functor for each row inside a thread.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = elt_[mi][0];
#pragma unroll
for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
}
}
}
// Apply the functor for each row inside each group of 4 threads.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
__syncwarp();
dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
__syncwarp();
}
// Store the different values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
if( tidx_ % 4 == 0 ) {
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
}
}
// Make sure the values are in shared memory.
__syncthreads();
// Load 8 values (one per warp): two float4 loads, one per group of 4 warps.
float4 tmp[2];
if( tidx_ < Cta_tile::M ) {
tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
tmp[1] = reinterpret_cast<const float4 *>(&smem_[1 * ELEMENTS / 2])[tidx_];
}
// Compute the reduction of those 8 values in a binary-tree fashion.
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y);
tmp[1].z = Functor::apply(tmp[1].z, tmp[1].w);
tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
tmp[1].x = Functor::apply(tmp[1].x, tmp[1].z);
tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x);
// Make sure we can write to shared memory.
__syncthreads();
// Store the value back to shared memory.
if( tidx_ < Cta_tile::M ) {
smem_[tidx_] = tmp[0].x;
}
// Make sure the data is in shared memory.
__syncthreads();
// Finally read the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
}
}
// Do a CTA-wide reduction.
template<typename Functor>
inline __device__ void reduce(float (&dst)[MMAS_M * 2]) {
static_assert(Cta_tile::WARPS_M == 1 && (Cta_tile::WARPS_N == 4 || Cta_tile::WARPS_N == 8));
if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4 ) {
reduce_1x4<Functor>(dst);
} else if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8 ) {
reduce_1x8<Functor>(dst);
} else {
assert(false);
}
// Make sure we are done reading from shared memory.
__syncthreads();
}
// Scale all the elements.
inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge difference.
float inv_sum[MMAS_M * 2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
}
// Update the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] *= inv_sum[mi];
}
}
}
// The pointer to the mask.
const char *packed_mask_ptr_;
// Shared memory for the CTA-wide reduction.
float *smem_, *smem_write_, *smem_read_;
// The current thread index.
int tidx_;
// The elements.
float elt_[MMAS_M * 2][MMAS_N * 4];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile, typename Kernel_traits>
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
// The base class.
using Base = Softmax_base<Cta_tile, Kernel_traits>;
// The fragment.
using Fragment_a = fmha::Fragment_a<fmha::Row>;
static_assert(Fragment_a::NUM_REGS == 4);
// The MMAs.
enum { MMAS_M = Base::MMAS_M };
enum { MMAS_N = Base::MMAS_N };
// The accumulators.
using Accumulator = fmha::Fragment_accumulator;
using Accumulator_out = Fragment<uint16_t, 8>;
static_assert(Accumulator_out::NUM_REGS == 4);
static_assert(std::is_same<Accumulator::Data_type, float>::value);
// Ctor.
template<typename Params>
inline __device__ Softmax(const Params &params, void *smem, int bidb, int tidx)
: Base(params, smem, bidb, tidx), params_scale_bmm1_(params.scale_bmm1) {
}
// Store the tile after softmax.
template<typename Gmem_tile>
inline __device__ void store(Gmem_tile &gmem_tile) {
Accumulator_out acc[MMAS_M][MMAS_N];
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// The elements.
float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0];
float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1];
float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2];
float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3];
float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0];
float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1];
float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2];
float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3];
// Transform to accumulators.
acc[mi][ni].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
acc[mi][ni].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
acc[mi][ni].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
acc[mi][ni].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
}
}
// Delegate to the gmem tile to store.
gmem_tile.store(acc);
}
// Pack the data to a fragment for the next GEMM.
template<int K, int M>
inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ki = 0; ki < K; ++ki ) {
// 1st row - 4 elements per row.
float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];
float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];
// 2nd row - 4 elements per row.
float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];
float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];
float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];
float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];
// Pack to 4 registers.
dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
}
}
}
// Scale FP32 fragments
inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
const float scalef = reinterpret_cast<const float &>(this->params_scale_bmm1_);
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;
// 2nd row - 4 elements per row.
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;
}
}
}
const uint32_t params_scale_bmm1_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Row {};
struct Col {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, bool = (M & (M-1)) == 0 >
struct Next_power_of_two {
};
template< int M >
struct Next_power_of_two< M, true > { enum { VALUE = M }; };
template<>
struct Next_power_of_two< 3, false> { enum { VALUE = 4 }; };
template<>
struct Next_power_of_two< 5, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 6, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 7, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 9, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 10, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 11, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 12, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 13, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 14, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 15, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 24, false> { enum { VALUE = 32 }; };
template<>
struct Next_power_of_two< 48, false> { enum { VALUE = 64 }; };
template<>
struct Next_power_of_two< 80, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two< 96, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two<112, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two<144, false> { enum { VALUE = 256 }; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, bool = (N & (N-1)) == 0 >
struct Prev_power_of_two {
};
template< int N >
struct Prev_power_of_two< N, true > { enum { VALUE = N }; };
template<>
struct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; };
template<>
struct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; };
template<>
struct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; };
template<>
struct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, int N >
struct Div_up {
enum { VALUE = (M + N-1) / N };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B >
struct Max {
enum { VALUE = A >= B ? A : B };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B, int C >
struct Max_3 {
enum { VALUE = Max<Max<A, B>::VALUE, C>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B >
struct Min {
enum { VALUE = A <= B ? A : B };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int SIZE_IN_BYTES >
struct Uint_from_size_in_bytes {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<1> {
using Type = uint8_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<2> {
using Type = uint16_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<4> {
using Type = uint32_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<8> {
using Type = uint2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<16> {
using Type = uint4;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int WARPS_M, int WARPS_N, int WARPS_K >
struct Warp_masks {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; };
template<>
struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; };
template<>
struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; };
template<>
struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; };
template<>
struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; };
template<>
struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; };
template<>
struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; };
template<>
struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; };
template<>
struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; };
template<>
struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; };
template<>
struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; };
template<>
struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; };
template<>
struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; };
template<>
struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; };
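// The warp index is tidx / 32, i.e. bits 5..7 of tidx. For example, with WARPS_M = 1,
// WARPS_N = 4, WARPS_K = 1 the N mask 0x60 selects bits 5-6, and dividing by
// WARPS_M * THREADS_PER_WARP = 32 recovers the warp position along N.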
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename T >
inline __device__ __host__ T div_up(T m, T n) {
return (m + n-1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int clz(int x) {
for( int i = 31; i >= 0; --i ) {
if( (1 << i) & x ) {
return 31 - i;
}
}
return 32;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int find_log_2(int x, bool round_up = false) {
int a = 31 - clz(x);
if( round_up ) {
a += (x & (x-1)) ? 1 : 0;
}
return a;
}
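// For example, find_log_2(48) == 5 and find_log_2(48, true) == 6, since 48 is not a power of
// two; for powers of two both variants agree.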
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
uint2 c;
c.x = hmul2(a.x, b.x);
c.y = hmul2(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
uint4 c;
c.x = hmul2(a.x, b.x);
c.y = hmul2(a.y, b.y);
c.z = hmul2(a.z, b.z);
c.w = hmul2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
uint4 c;
c.x = hmul2(a, b.x);
c.y = hmul2(a, b.y);
c.z = hmul2(a, b.z);
c.w = hmul2(a, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
uint32_t res;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb));
#else
const uint32_t zero = 0u;
asm volatile( \
"{\n" \
"\t .reg .f16x2 sela;\n" \
"\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
"\t and.b32 %0, sela, %1;\n"
"}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
return res;
}
static inline __device__ uint32_t habs2(uint32_t x) {
uint32_t res;
asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename T >
static inline __device__ T clamp(T x, T lb, T ub) {
return x < lb ? lb : (x > ub ? ub : x);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t clamp_to_zero(uint16_t x) {
uint16_t mask;
asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x));
return mask & x;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t float_to_half(float f) {
uint16_t h;
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
return h;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float2_to_half2(float a, float b) {
uint32_t c;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
#else
uint16_t lo = float_to_half(a);
uint16_t hi = float_to_half(b);
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi));
#endif
return c;
}
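// Per the PTX ISA, cvt.rn.f16x2.f32 packs its first source operand into the upper half of the
// result, so passing (b, a) places a in the low 16 bits, matching the pre-SM80 path.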
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float_to_half2(float a) {
return float2_to_half2(a,a);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float2_to_half2(const float2 &f) {
return float2_to_half2(f.x, f.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) {
uint2 d;
d.x = float2_to_half2(x, y);
d.y = float2_to_half2(z, w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
d = hrelu2(hfma2(a, b, c));
#endif
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t h0_h0(uint32_t x) {
uint32_t y;
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n"
: "=r"(y) : "r"(x));
return y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float h0_to_float(uint32_t h2) {
float f;
asm volatile("{\n" \
".reg .f16 lo, hi;\n" \
"mov.b32 {lo, hi}, %1;\n" \
"cvt.f32.f16 %0, lo;\n" \
"}\n" : "=f"(f) : "r"(h2));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t h1_h1(uint32_t x) {
uint32_t y;
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n"
: "=r"(y) : "r"(x));
return y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {
uint16_t d;
asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) {
return hadd2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hadd4(uint2 a, uint2 b) {
uint2 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hadd(uint2 a, uint2 b) {
return hadd4(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hadd8(uint4 a, uint4 b) {
uint4 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
c.z = hadd2(a.z, b.z);
c.w = hadd2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 fadd4(uint4 a, uint4 b) {
float4 c;
c.x = reinterpret_cast<const float&>(a.x) + reinterpret_cast<const float&>(b.x);
c.y = reinterpret_cast<const float&>(a.y) + reinterpret_cast<const float&>(b.y);
c.z = reinterpret_cast<const float&>(a.z) + reinterpret_cast<const float&>(b.z);
c.w = reinterpret_cast<const float&>(a.w) + reinterpret_cast<const float&>(b.w);
return reinterpret_cast<const uint4&>(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hadd(uint4 a, uint4 b) {
return hadd8(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float half_to_float(uint16_t h) {
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float2 half2_to_float2(uint32_t x) {
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x));
return make_float2(half_to_float(lo), half_to_float(hi));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) {
float2 tmp = half2_to_float2(h);
x = tmp.x;
y = tmp.y;
}
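////////////////////////////////////////////////////////////////////////////////////////////////////
// A minimal round-trip sketch of the packed-fp16 helpers above: float2_to_half2 packs its first
// argument into the low 16 bits and its second into the high 16 bits, and half2_to_float2 undoes it.
//
//   uint32_t h2 = float2_to_half2(1.5f, -2.0f);   // two fp16 values in one 32-bit register
//   float2 f2   = half2_to_float2(h2);            // f2.x == 1.5f, f2.y == -2.0f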
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {
uint16_t d;
asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {
uint16_t d;
asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float sigmoid(float x) {
return 1.f / (1.f + expf(-x));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint16_t &dst) {
dst = uint16_t(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint32_t &dst) {
dst = 0u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint2 &dst) {
dst = make_uint2(0u, 0u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint4 &dst) {
dst = make_uint4(0u, 0u, 0u, 0u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// P R E D I C A T E P A C K I N G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE };
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// G E N E R I C P R E D I C A T E D L D G S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M, typename Functor >
inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) {
// The number of complete bytes (where we use all the predicates in a byte).
enum { COMPLETE = N / PREDS_PER_BYTE };
// Make sure we did allocate enough predicates.
static_assert(Div_up<COMPLETE, BYTES_PER_REG>::VALUE <= M, "");
// The remainder.
enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE };
// Make sure we got the math right and the remainder is between 0 and 3.
static_assert(REMAINDER >= 0 && REMAINDER <= 3, "");
// The mask to extract the predicates.
enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 };
// Clear the fetch registers.
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
fct.clear(ii);
}
// Run complete steps.
bool p[PREDS_PER_BYTE];
#pragma unroll
for( int ii = 0; ii < COMPLETE; ++ii ) {
// The predicate.
uint32_t reg = preds[ii / BYTES_PER_REG];
// Extract the predicates.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);
p[jj] = (reg & mask) != 0u;
}
// Issue the loads.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
fct.load(ii * PREDS_PER_BYTE + jj, p[jj]);
}
}
// Skip the rest of the code if we do not have a remainder.
if( REMAINDER > 0 ) {
// The mask to extract the predicates.
enum { REMAINDER_MASK = (1 << REMAINDER) - 1 };
// The predicate register.
uint32_t reg = preds[COMPLETE / BYTES_PER_REG];
// Extract the predicates.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);
p[jj] = (reg & mask) != 0u;
}
// Issue the loads.
#pragma unroll
for( int ii = 0; ii < REMAINDER; ++ii ) {
fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, typename Functor >
inline __device__ void load_(Functor &fct, uint32_t preds) {
uint32_t tmp[1] = { preds };
load_<M>(fct, tmp);
}
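////////////////////////////////////////////////////////////////////////////////////////////////////
// Illustrative helper (unused by the kernels below): builds the packed predicate word consumed by
// load_ for up to PREDS_PER_REG guarded loads. Load index ii maps to bit
// (ii / PREDS_PER_BYTE) * 8 + (ii % PREDS_PER_BYTE), i.e. 4 predicates per byte of the register.
template< int N >
static inline __device__ uint32_t pack_predicates_sketch(const bool (&p)[N]) {
    static_assert(N <= PREDS_PER_REG, "");
    uint32_t preds = 0u;
    #pragma unroll
    for( int ii = 0; ii < N; ++ii ) {
        preds |= (p[ii] ? 1u : 0u) << (ii / PREDS_PER_BYTE * 8 + ii % PREDS_PER_BYTE);
    }
    return preds;
}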
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint8_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint8_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint16_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint16_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint32_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint32_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint2 &dst, const void *ptr) {
dst = *reinterpret_cast<const uint2*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint4 &dst, const void *ptr) {
dst = *reinterpret_cast<const uint4*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N >
struct Ldg_functor {
// Ctor.
inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N])
: fetch_(fetch), ptrs_(ptrs) {
}
// Clear the element.
inline __device__ void clear(int ii) {
fmha::clear(fetch_[ii]);
}
// Trigger the loads.
inline __device__ void load(int ii, bool p) {
if( p ) {
ldg(fetch_[ii], ptrs_[ii]);
}
}
// The fetch registers.
Data_type (&fetch_)[N];
// The pointers.
const void* (&ptrs_)[N];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N, int M >
inline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
Ldg_functor<Data_type, N> fct(fetch, ptrs);
load_<N>(fct, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint8_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint16_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint32_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint2, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint4, N>(fetch, ptrs, preds);
}
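////////////////////////////////////////////////////////////////////////////////////////////////////
// A minimal usage sketch of the guarded loads above (all names illustrative): gather four 16-byte
// rows; rows whose predicate bit is 0 are skipped and their fetch registers stay cleared by load_.
//
//   uint4 fetch[4];
//   const void *ptrs[4] = { row0, row1, row2, row3 };
//   uint32_t preds[1] = { pack_predicates_sketch(valid) };   // bool valid[4]
//   ldg(fetch, ptrs, preds);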
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint16_t &dst, uint32_t ptr) {
asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint32_t &dst, uint32_t ptr) {
asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint2 &dst, uint32_t ptr) {
asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint4 &dst, uint32_t ptr) {
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x)
, "=r"(dst.y)
, "=r"(dst.z)
, "=r"(dst.w)
: "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S M
//
////////////////////////////////////////////////////////////////////////////////////////////////////
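// ldmatrix loads 8x8 tiles of 16-bit elements from shared memory, one 32-bit register per tile per
// thread across the warp; the .x1/.x2/.x4 variants load 1, 2 or 4 tiles, and the .trans variants
// (ldsmt below) return the transposed fragments.
////////////////////////////////////////////////////////////////////////////////////////////////////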
inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
: "=r"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n"
: "=r"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint8_t val) {
*reinterpret_cast<uint8_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint16_t val) {
*reinterpret_cast<uint16_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint32_t val) {
*reinterpret_cast<uint32_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint2 val) {
*reinterpret_cast<uint2*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint4 val) {
*reinterpret_cast<uint4*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint16_t val) {
asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint32_t val) {
asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint2 val) {
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n"
:
: "r"(ptr)
, "r"(val.x)
, "r"(val.y));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint4 val) {
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
:
: "r"(ptr)
, "r"(val.x)
, "r"(val.y)
, "r"(val.z)
, "r"(val.w));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N >
inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
sts(ptrs[ii], data[ii]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {
sts_<uint16_t, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {
sts_<uint32_t, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {
sts_<uint2, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
sts_<uint4, N>(ptrs, data);
}
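////////////////////////////////////////////////////////////////////////////////////////////////////
// A minimal shared-memory round-trip sketch (illustrative; assumes `smem_ptr` is a 32-bit shared
// memory address of the kind the tiles compute, e.g. via cvta.to.shared / __cvta_generic_to_shared):
//
//   sts(smem_ptr, float2_to_half2(1.0f, 2.0f));   // store one packed half2
//   __syncthreads();
//   uint32_t val;
//   lds(val, smem_ptr);                           // read it back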
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 128 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_128_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_128_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
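// A minimal host-side sketch (assuming `params` was populated by the fmhalib setup code and the
// dQKV / S buffers it points to are allocated): launch the S=128 dgrad kernel on the current
// PyTorch stream.
//
//   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
//   run_fmha_dgrad_fp16_128_64_sm80(params, stream);
//   FMHA_CHECK_CUDA(cudaPeekAtLastError());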
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 256 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_256_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_256_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 384 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_384_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_384_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 512 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_512_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dv_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dv =
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);
static_assert(Cta_tile_dv::N == 64);
static_assert(Cta_tile_dv::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dv = fmha::Hmma_tile<Cta_tile_dv>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
// using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The shared memory tile to reload Q as fragment b.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store dV.
using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle dV.
using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dv>;
static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);
static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_do gmem_q(params, binfo, tidx); // treating dout as Q
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
Smem_tile_qt smem_qt(&smem_[0], tidx);
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 2, binfo, tidx); // treating V as K
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
smem_q.load(frag_q[0], 0);
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];
static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
static_assert(Mma_tile_dv::MMAS_K == 1);
smem_qt.load(frag_qt[0], 0);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];
smem_k.load(frag_k[0], 0);
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
Softmax softmax(
params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx);
enum { THREADS_PER_ROW = 32 };
enum { M = Mma_tile_p::MMAS_M };
enum { N = Mma_tile_p::MMAS_N };
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);
// Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
if( loop >= binfo.actual_seqlen )
break;
        // Load S = softmax(Q x K^T) written back by the forward pass; the keep/drop decision of
        // dropout is encoded in the sign of each element.
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q and K values.
smem_q.load(frag_q[ki & 1], ki);
smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Store s * dmask to smem for transpose
smem_s.store(s_regs);
        // Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
float s_mat[2 * M][4 * N];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 &dst = s_regs[mi][ni];
fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);
fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);
fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);
fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);
}
}
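        // Dropout backward: the stored values are s * dmask with the keep/drop decision encoded in
        // the sign bit, so |s_dmask| is the softmax probability. Zero the gradient of dropped
        // elements and rescale the kept ones by params.rp_dropout.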
float d_s[2 * M][4 * N];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
const float s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;
d_s[2 * mi + ii][4 * ni + jj] = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s[2 * mi + ii][4 * ni + jj] * fabsf(s_dmask);
}
}
}
}
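        // Softmax backward: softmax.elt_ currently holds d_s * P (elementwise). Reduce it over each
        // row, then rewrite elt_ as (d_s - rowsum) * P, folding in the softmax scale.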
float p_sum[2 * M];
softmax.template reduce<fmha::Sum_>(p_sum);
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
softmax.elt_[2 * mi + ii][4 * ni + jj] = (d_s[2 * mi + ii][4 * ni + jj] - p_sum[2 * mi + ii]) *
fabsf(s_mat[2 * mi + ii][4 * ni + jj]) * scalef;
}
}
}
}
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];
smem_s.load(frag_s);
for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {
for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) {
for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) {
frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);
frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
#pragma unroll
for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dv::MMAS_K;
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
gmem_q.commit(smem_q);
}
// Make sure we are reading from the correct buffer.
smem_q.move_to_next_read_buffer();
smem_qt.move_to_next_read_buffer();
// Make sure the data is in shared memory.
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_q.load(frag_q[0], 0);
smem_k.load(frag_k[0], 0);
smem_qt.load(frag_qt[0], 0);
} // Outer loop over the sequence length.
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);
uint4 dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
#pragma unroll
for( int mi = 0; mi < Mma_tile_dv::MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_dv::MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
float tmp00 = acc_dv[mi][ni].elt(0);
float tmp01 = acc_dv[mi][ni].elt(1);
float tmp02 = acc_dv[mi][ni].elt(4);
float tmp03 = acc_dv[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc_dv[mi][ni].elt(2);
float tmp11 = acc_dv[mi][ni].elt(3);
float tmp12 = acc_dv[mi][ni].elt(6);
float tmp13 = acc_dv[mi][ni].elt(7);
dv[mi][ni].x = fmha::float2_to_half2(tmp00, tmp01);
dv[mi][ni].y = fmha::float2_to_half2(tmp02, tmp03);
dv[mi][ni].z = fmha::float2_to_half2(tmp10, tmp11);
dv[mi][ni].w = fmha::float2_to_half2(tmp12, tmp13);
}
}
smem_dv.store(dv);
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out);
Qkv_params dv_params;
dv_params.qkv_ptr = params.dqkv_ptr;
dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
gmem_dv.store(dv_out);
}
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dq_dk_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dk =
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);
static_assert(Cta_tile_dk::N == 64);
static_assert(Cta_tile_dk::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_v; // K is used like V in fprop
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
// using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
using Gmem_tile_o = fmha::Gmem_tile_dq<Cta_tile_o>;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store dK.
using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle dK.
using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;
static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);
// The shared memory tile to reload Q transposed.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
enum { M = Mma_tile_p::MMAS_M };
enum { N = Mma_tile_p::MMAS_N };
static_assert(M == Mma_tile_o::MMAS_M);
static_assert(N == Mma_tile_o::MMAS_K);
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
Smem_tile_qt smem_qt(&smem_[0], tidx);
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
// Load dP
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
gmem_s.move();
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Make sure the data is in shared memory.
__syncthreads();
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];
smem_qt.load(frag_qt[0], 0);
typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];
smem_k.load(frag_k[0], 0);
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
enum { THREADS_PER_ROW = 32 };
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);
// Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
if( loop >= binfo.actual_seqlen )
break;
// Pack dP as Fragment_a
fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 &dst = s_regs[mi][ni];
frag_p[ni][mi].reg(0) = dst.x; // row 0, cols 0,1
frag_p[ni][mi].reg(1) = dst.z; // row 8, cols 0,1
frag_p[ni][mi].reg(2) = dst.y; // row 0, cols 8,9
frag_p[ni][mi].reg(3) = dst.w; // row 8, cols 8,9
}
}
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
        // Do this part of dQ = dP x K. This GEMM has the same shape as O = P x V in fprop.
#pragma unroll
for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of K values.
smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_o::MMAS_K;
fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
}
// Store dP to smem for transpose
smem_s.store(s_regs);
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
// Load next part of S
gmem_s.load(s_regs, mask);
gmem_s.move();
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
smem_s.load(frag_s);
#pragma unroll
for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dk::MMAS_K;
fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
gmem_q.commit(smem_q);
}
// Make sure the data is in shared memory.
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_qt.load(frag_qt[0], 0);
smem_k.load(frag_k[0], 0);
} // Outer loop over the sequence length.
// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[0], tidx);
uint4 dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
#pragma unroll
for( int mi = 0; mi < Mma_tile_dk::MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_dk::MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
float tmp00 = acc_dk[mi][ni].elt(0);
float tmp01 = acc_dk[mi][ni].elt(1);
float tmp02 = acc_dk[mi][ni].elt(4);
float tmp03 = acc_dk[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc_dk[mi][ni].elt(2);
float tmp11 = acc_dk[mi][ni].elt(3);
float tmp12 = acc_dk[mi][ni].elt(6);
float tmp13 = acc_dk[mi][ni].elt(7);
dk[mi][ni].x = fmha::float2_to_half2(tmp00, tmp01);
dk[mi][ni].y = fmha::float2_to_half2(tmp02, tmp03);
dk[mi][ni].z = fmha::float2_to_half2(tmp10, tmp11);
dk[mi][ni].w = fmha::float2_to_half2(tmp12, tmp13);
}
}
smem_dk.store(dk);
__syncthreads();
uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
Qkv_params dk_params;
dk_params.qkv_ptr = params.dqkv_ptr;
dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
gmem_dk.store(dk_out);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits< 128, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
}
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_128_64_sm80_train_kernel : &fmha_fprop_fp16_128_64_sm80_predict_kernel;
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
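    // Shared memory budget: the Q tile plus the larger of the V tile and the O tile + softmax
    // scratch; the latter two alternatives are not needed at the same time and share one region.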
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits< 256, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
}
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_256_64_sm80_train_kernel : &fmha_fprop_fp16_256_64_sm80_predict_kernel;
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN_reload_v.h"
using Kernel_traits = FMHA_kernel_traits< 384, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
}
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_384_64_sm80_train_kernel : &fmha_fprop_fp16_384_64_sm80_predict_kernel;
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size = smem_size_v + smem_size_o + smem_size_softmax;
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits< 512, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, true>(params);
}
extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN<Kernel_traits, false>(params);
}
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {
auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
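////////////////////////////////////////////////////////////////////////////////////////////////////
// Illustrative sketch (not part of this extension): the four launchers in this commit differ only
// in the maximum sequence length baked into their Kernel_traits (128, 256, 384, 512) with a head
// dimension of 64. A host-side caller could pick one from the padded sequence length roughly as
// below; the dispatcher name and the exact dispatch policy are hypothetical.
static void run_fmha_fprop_sm80_sketch(const Fused_multihead_attention_fprop_params &params,
                                       const int max_seqlen,
                                       const bool is_training,
                                       cudaStream_t stream) {
    if (max_seqlen <= 128) {
        run_fmha_fp16_128_64_sm80(params, is_training, stream);
    } else if (max_seqlen <= 256) {
        run_fmha_fp16_256_64_sm80(params, is_training, stream);
    } else if (max_seqlen <= 384) {
        run_fmha_fp16_384_64_sm80(params, is_training, stream);
    } else if (max_seqlen <= 512) {
        run_fmha_fp16_512_64_sm80(params, is_training, stream);
    } else {
        // Sequence lengths above 512 are not covered by these kernels.
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////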
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Kernel_traits, bool Is_training, typename Params> inline __device__ void device_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));
Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v.
char *smem_v_ = nullptr;
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
} else {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
}
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Trigger the loads for V.
gmem_v.load(smem_v);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Commit the data for V to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_v.commit(smem_v);
}
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
smem_q.load(frag_q[0], 0);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_k.load(frag_k[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_v.commit(smem_v);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
// Create the object to do the softmax.
using Softmax = fmha::Softmax< Cta_tile_p, Kernel_traits>;
Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
enum { THREADS_PER_ROW = 32 };
// Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
if( loop >= binfo.actual_seqlen )
break;
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Store the P matrix.
#if defined(STORE_P)
gmem_p.store(acc_p);
#endif
// Load the mask for that iteration.
mask.load(outer);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && loop == 0 ) {
// If K and V share shared memory, V may not have been fully read yet, but we are about to write into that memory for the softmax reduction.
__syncthreads();
}
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Max_>(p_max);
// Make sure we are done reading shared memory.
__syncthreads();
// Compute the exponential value.
softmax.apply_exp(p_max);
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Sum_>(p_sum);
// Finalize softmax on the accumulators of P^T.
softmax.scale(p_sum);
if( Is_training ) {
auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
float4 tmp = uniform4(ph());
// We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from
// pre-existing zeros
softmax.elt_[2 * mi + ii][4 * ni + 0] =
encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
softmax.elt_[2 * mi + ii][4 * ni + 1] =
encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
softmax.elt_[2 * mi + ii][4 * ni + 2] =
encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
softmax.elt_[2 * mi + ii][4 * ni + 3] =
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
}
// Trigger the load for the next Q values.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
using Frag_p = fmha::Fragment_a< fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
softmax.pack(frag_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
//"Apply" the dropout.
frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
}
}
}
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
// Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
gmem_q.commit(smem_q);
}
// Make sure the data is in shared memory.
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_q.load(frag_q[0], 0);
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
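////////////////////////////////////////////////////////////////////////////////////////////////////
// Illustrative sketch (not part of this extension): the training path above encodes the dropout
// mask in the sign bit of the otherwise non-negative softmax values, writes that signed tensor out
// through gmem_s for the backward pass, and then, before the second GEMM, rescales with
// params.scale_dropout and applies hrelu2 so that dropped (negative) values become zero. The scalar
// model below assumes params.p_dropout acts as the keep probability and that the rescale factor is
// 1 / p_keep, which is what the comparison `tmp.x <= params.p_dropout` and the rescale suggest;
// the function names here are hypothetical.
static inline float encode_dropout_scalar(bool keep, float val) {
    return keep ? val : -val;                    // the sign bit carries the keep/drop decision
}
static inline float decode_dropout_scalar(float encoded, float p_keep) {
    float scaled = encoded * (1.f / p_keep);     // inverted-dropout rescale (role of scale_dropout)
    return scaled > 0.f ? scaled : 0.f;          // dropped (negative) values become exactly zero
}
////////////////////////////////////////////////////////////////////////////////////////////////////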
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Kernel_traits, bool Is_training, typename Params> inline __device__ void device_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
Mask<Cta_tile_p> mask(params, binfo, tidx);
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));
static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[0], tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v.
char *smem_v_ = nullptr;
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
smem_v_ = &smem_[0];
} else {
smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
}
static_assert(Kernel_traits::SHARE_SMEM_FOR_K_AND_V);
static_assert(Smem_tile_k::BYTES_PER_TILE == Smem_tile_v::BYTES_PER_TILE);
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same buffer as Q so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Trigger the loads for V.
gmem_v.load(smem_v);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Commit the data for V to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_v.commit(smem_v);
}
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[1][Mma_tile_p::MMAS_M];
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_k.load(frag_k[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_v.commit(smem_v);
}
enum { BITS_PER_ELT_S = sizeof(typename fmha::A_type) * 8 };
Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
// Create the object to do the softmax.
using Softmax = fmha::Softmax< Cta_tile_p, Kernel_traits>;
Softmax softmax(params, &smem_[Smem_tile_v::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);
constexpr int SMEM_BYTES_SOFTMAX = Softmax::ELEMENTS * sizeof(float);
static_assert(SMEM_BYTES_SOFTMAX == Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float));
enum { THREADS_PER_ROW = 32 };
const float pinv = 1.f / params.p_dropout;
// Load over the entire sequence length.
for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
if( loop >= binfo.actual_seqlen )
break;
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[0], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[0], frag_k[ki]);
}
// Load the mask for that iteration.
mask.load(outer);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Max_>(p_max);
// Make sure we are done reading shared memory.
__syncthreads();
// Compute the exponential value.
softmax.apply_exp(p_max);
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
softmax.template reduce<fmha::Sum_>(p_sum);
// Finalize softmax on the accumulators of P^T.
softmax.scale(p_sum);
__syncthreads();
if( Is_training ) {
auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
float4 tmp = uniform4(ph());
// We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from
// pre-existing zeros
softmax.elt_[2 * mi + ii][4 * ni + 0] =
encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
softmax.elt_[2 * mi + ii][4 * ni + 1] =
encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
softmax.elt_[2 * mi + ii][4 * ni + 2] =
encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
softmax.elt_[2 * mi + ii][4 * ni + 3] =
encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
}
// Trigger the load for the next Q values.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
typename Smem_tile_v::Fragment frag_v[1][Mma_tile_o::MMAS_N];
using Frag_p = fmha::Fragment_a< fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
softmax.pack(frag_p);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
//"Apply" the dropout.
frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
}
}
}
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of V values.
smem_v.load(frag_v[0], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_o, frag_p[ki], frag_v[0]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Always sync, even after the last iteration: smem_q and smem_o share the same buffer!
__syncthreads();
// Output the values.
gmem_o.store(out, ii);
}
// Note: smem_q shares the same shared memory as smem_o.
// Move to the next part of the output.
gmem_o.move();
// Commit the values for Q into shared memory.
if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
gmem_q.commit(smem_q);
}
// Make sure the data is in shared memory.
__syncthreads();
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
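////////////////////////////////////////////////////////////////////////////////////////////////////
// Illustrative sketch (not part of this extension): both kernel variants run the same numerically
// stable softmax per row of P: reduce the row max, exponentiate with the max subtracted, reduce the
// row sum, then scale by its reciprocal (the reduce<Max_>, apply_exp, reduce<Sum_>, scale sequence
// above). The single-row, single-thread reference below is only a model of that sequence; the
// kernels do it with warp/CTA reductions over register fragments, and the function name here is
// hypothetical.
#include <cmath>

static void softmax_row_reference(float *row, int n) {
    float row_max = row[0];
    for (int i = 1; i < n; ++i) {
        row_max = row[i] > row_max ? row[i] : row_max;
    }
    float row_sum = 0.f;
    for (int i = 0; i < n; ++i) {
        row[i] = std::exp(row[i] - row_max);  // subtract the max for numerical stability
        row_sum += row[i];
    }
    const float inv_sum = 1.f / row_sum;
    for (int i = 0; i < n; ++i) {
        row[i] *= inv_sum;                    // normalize so the row sums to 1
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////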