Commit 4f285b35 authored by Tri Dao

FlashAttention-2 release

parent 6d48e14a
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
}
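// Each of these translation units defines exactly one explicit specialization of
// run_mha_fwd_<element type, head dimension>, so nvcc can compile the head dimensions in
// parallel rather than in one huge file; they all pull the launcher templates
// (run_mha_fwd_hdim*) from flash_fwd_launch_template.h. A host-side dispatch over the
// specializations might look roughly like the following sketch (illustrative only; the real
// switch lives in the host API, and is_bf16 / d are assumed fields of Flash_fwd_params,
// mirroring the FMHA_fprop_params struct further down in this commit):
//
//   void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
//       if (!params.is_bf16) {
//           if (params.d <= 32) { run_mha_fwd_<cutlass::half_t, 32>(params, stream); }
//           else if (params.d <= 64) { run_mha_fwd_<cutlass::half_t, 64>(params, stream); }
//           // ... 96, 128, 160, 192, 224, 256 ...
//       } else {
//           if (params.d <= 32) { run_mha_fwd_<cutlass::bfloat16_t, 32>(params, stream); }
//           // ...
//       }
//   }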
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
// // For dropout there might be a lot of register spilling?
// // These two are very slow due to register spilling
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
// // This one is slightly slower
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
// });
// }
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// if (params.p_dropout == 1.f) {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
// }
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// if (params.p_dropout == 1.f) {
// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// // Using block size (64 x 256) is 27% slower for seqlen=2k
// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
// } else {
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
// }
// }
template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::bfloat16_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
// });
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_launch_template.h"
// template<>
// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
// using elem_type = cutlass::half_t;
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
// // This 3rd one is good for H100, A100, and A6000
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
// // These two are always slower
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
// });
// }
template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
}
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cmath>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>
#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "philox.cuh"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMA_M,
class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
constexpr int MMAStride_M = MMA_M * AtomShape_M;
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
Stride<_1, Int<MMAStride_M>> >{},
make_layout(size<2>(TileShape_MNK{})));
// if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMA_M,
class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
constexpr int MMAStride_M = MMA_M * AtomShape_M;
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
Stride<_1, Int<MMAStride_M>> >{},
// TODO: Shouldn't this be size<1>?
make_layout(size<2>(TileShape_MNK{})));
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
Tensor2 &acc_o, float softmax_scale_log2) {
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
flash::reduce_sum(scores, scores_sum);
} else {
Tensor scores_max_prev = make_fragment_like(scores_max);
copy(scores_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < size(scores_max); ++mi) {
float scores_max_cur = !Check_inf
? scores_max(mi)
: (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
scores_sum(mi) *= scores_scale;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
Tensor scores_sum_cur = make_fragment_like(scores_sum);
flash::reduce_sum(scores, scores_sum_cur);
#pragma unroll
for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
}
};
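// A compact restatement of the online-softmax recurrence implemented above (a sketch of the
// math, not extra code): with m = scores_max (running row max of the raw Q K^T scores),
// l = scores_sum (running row sum) and O = acc_o (running output), the first call initializes m
// and l from the current block; every later call computes m_new = max(m_old, m_cur), rescales
// the old l and the old rows of O by exp(softmax_scale * (m_old - m_new)) (evaluated as
// exp2f(delta * softmax_scale_log2)), exponentiates the current block against m_new, and adds
// its row sums into l. Check_inf replaces an m of -infinity (a fully masked row) by 0 so the
// rescaling factor stays finite.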
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout();
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
if (Is_causal) {
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
// }
}
// We iterate over the blocks in reverse order. This is because the last block is the only one
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
// We move K and V to the last block.
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.seqlen_k_rounded, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
// Careful: we're using the same smem for sQ and for sK | sV when Share_Q_K_smem is set.
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
//
// Copy Atom retiling
//
auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
// auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
// TODO: this might need to change if we change the mma instruction in SM70
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
Tensor scores_sum = make_fragment_like(scores_max);
//
// PREDICATES
//
// // Allocate predicate tensors for m and n
// Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
// Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
// Construct identity layout for sQ and sK
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K)
// if (cute::thread0()) {
// print(tScQ.layout()); printf("\n");
// for (int i = 0; i < size(tScQ); ++i) {
// printf("%d ", get<0>(tScQ(i)));
// }
// printf("\n");
// for (int i = 0; i < size(tScQ); ++i) {
// printf("%d ", get<1>(tScQ(i)));
// }
// printf("\n");
// }
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
// Set predicates for k bounds
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
#pragma unroll
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
}
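// Example of the Is_even_K == false path above: with an actual head dimension params.d = 40
// padded up to kHeadDim = 64, a thread's predicate tQpQ(k) / tKVpKV(k) is cleared whenever the
// head-dim coordinate of its k-th column chunk is >= 40, so the copies below never read the
// nonexistent columns 40..63.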
// Prologue
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
// // Copy rmem to smem
// // copy(tQrQ, tQsQ);
// flash::cp_async_wait<0>();
// __syncthreads();
// // if (cute::thread(1, 0)) { print(tQsQ); }
// // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
// // if (cute::thread0()) { print(sQNoSwizzle); }
if (Kernel_traits::Share_Q_K_smem) {
flash::cp_async_wait<0>();
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
}
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
// __syncthreads();
if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
flash::cp_async_wait<1>();
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
}
auto seeds = at::cuda::philox::unpack(params.philox_args);
unsigned long long seed = std::get<0>(seeds);
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
clear(acc_o);
// For performance reasons, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
// We need masking on S for the very last block when the K and V length is not a multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
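// For example, with a causal mask and kBlockM = 128, kBlockN = 64, n_masking_steps =
// ceil_div(128, 64) = 2: only the two right-most K/V blocks that can intersect the diagonal of
// this row block are processed with masking, and all earlier blocks fall through to the
// unmasked loop further down.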
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// if (cute::thread0()) { print(scores); }
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (!Is_causal) {
if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
// static_assert(decltype(size<0>(taccScS))::value == 4);
// // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
// Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
// Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout()));
// flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM);
// Idk why it's get<1> and not get<0> of the stride.
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
// I can't get the stride from idx_row
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
}
flash::cp_async_wait<0>();
__syncthreads();
if (n_block > 0) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
}
// TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
uint32_t block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps);
}
// if (cute::thread0()) { print(tOrP); }
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration
if (n_masking_steps > 1 && n_block <= 0) {
--n_block;
break;
}
}
// These are the iterations where we don't need masking on S
for (; n_block >= 0; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
);
flash::cp_async_wait<0>();
__syncthreads();
if (n_block > 0) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
}
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
uint32_t block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps);
}
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
}
// Epilogue
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
Tensor lse = make_fragment_like(scores_sum);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = scores_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
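// In other words, with m = scores_max(mi) (row max of the raw scores) and
// s = scores_sum(mi) = sum_j exp(softmax_scale * (score_j - m)), the epilogue stores the
// log-sum-exp lse = softmax_scale * m + log(s) and divides each output row by s
// (and additionally by the keep probability, via rp_dropout, when dropout is enabled).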
// if (cute::thread0()) { print(acc_o_rowcol); }
// Convert acc_o from fp32 to fp16/bf16
Tensor rO = flash::convert_type<Element>(acc_o);
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
// auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// sO has the same size as sQ, so we don't need to sync here.
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
copy(smem_thr_copy_O, taccOrO, taccOsO);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
__syncthreads();
Tensor tOrO = make_tensor<Element>(shape(tOgO));
copy(gmem_thr_copy_O, tOsO, tOrO);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
static_assert(decltype(size<0>(taccOcO))::value == 4);
// Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
if (get<1>(taccOcO_row(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
}
}
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
// them to have the same number of threads or to traverse the attention matrix
// in the same order.
// In the Philox RNG, we use the offset to store the batch, head, and the lane id
// (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}
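// Concretely (matching the unpack call in compute_attn_1rowblock above): the Philox offset is
// base_offset + (bidb * params.h + bidh) * 32 + lane, and the subsequence is built from
// block_row_idx = m_block * (kBlockM / 16) + (tidx / 32) and block_col_idx = n_block * (kBlockN / 32),
// i.e. the coordinates of the 16 x 32 tile inside the attention matrix.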
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
// We also use is_even_N to set Varlen in the BlockInfo constructor, so we need to check
// for cu_seqlens_q as well.
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
// Only return the softmax if dropout is enabled, to reduce compilation time.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
}
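// Illustrative call path (a minimal sketch, assuming fp16 inputs with params.d == 64 and no
// dropout): run_mha_fwd_hdim64<cutlass::half_t>(params, stream) below picks
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, cutlass::half_t>,
// /*Is_dropout=*/false, Is_causal>, which launches flash_fwd_kernel on a grid of
// (ceil(seqlen_q / 128), batch, heads) blocks of Kernel_traits::kNThreads threads, opting in to
// more than 48 KB of dynamic shared memory when the traits require it.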
template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 32;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 64;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// Using block size (64 x 256) is 27% slower for seqlen=2k
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
});
});
}
template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 96;
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square).
if (is_sm8x) {
if constexpr(!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// These two are always slower
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 128;
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
if (is_sm8x) {
if constexpr(!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
// The 1st ones are good for H100 and A100.
// The 2nd one is good for A6000 because we get slightly better occupancy.
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
}
});
});
}
template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 160;
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
// and 128 x 64 with 8 warps is the fastest for non-causal.
if (is_sm8x) {
if constexpr(!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 192;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 224;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
// printf("max_smem_per_block = %d\n", max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
// If we have N = 32, there are only 1024 elements to load at once, where each load
// is 8 elements. This means we can only use 128 threads and not 256 threads.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_sm, max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, we want to run with 128 x 64 (128KB smem).
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
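// In bytes (fp16/bf16, 2 bytes per element): the 128 x 64 configuration needs
// 2 * 256 * (128 + 2 * 64) = 128 KB of smem per block, the 64 x 64 configuration needs 96 KB,
// and 4 * Headdim * (64 + 2 * 64) = 192 KB is roughly the smem an SM would need to keep two
// 64 x 64 CTAs resident; when the SM has that much (e.g. H100), the smaller tile is preferred.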
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// 64 KB
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// 96 KB
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
});
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/UnpackRaw.cuh>
#include <fmha_utils.h>
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
// size_t qkv_stride_in_elts;
// size_t qkv_stride_in_bytes;
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
uint32_t q_row_stride_in_elts;
uint32_t k_row_stride_in_elts;
uint32_t v_row_stride_in_elts;
uint32_t q_head_stride_in_elts;
uint32_t k_head_stride_in_elts;
uint32_t v_head_stride_in_elts;
// The number of heads.
int h;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FMHA_fprop_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
// The stride between rows of O.
// size_t o_stride_in_elts;
// size_t o_stride_in_bytes;
uint32_t o_row_stride_in_elts;
uint32_t o_head_stride_in_elts;
uint32_t o_tmp_row_stride_in_elts;
uint32_t o_tmp_head_stride_in_elts;
// The pointer to the O_tmp matrix, which holds intermediate O values during
// the loop.
void *__restrict__ o_tmp_ptr;
// The pointer to the S matrix.
void * __restrict__ s_ptr;
// The stride between rows of the S matrix.
// int64_t s_stride_in_bytes;
uint32_t s_stride_in_bytes;
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, d;
// The scaling factors for the kernel.
float scale_bmm1f;
uint32_t scale_bmm1;
// Array of length b+1 holding the starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
int *__restrict__ blockmask;
// The dropout probability (probability of keeping an activation).
float p_dropout;
uint32_t p_dropout_in_uint;
uint16_t p_dropout_in_uint16_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_bmm1_rp_dropout;
// Scale factor of 1 / (1 - p_dropout), in half2.
uint32_t scale_dropout;
// Random state.
at::PhiloxCudaState philox_args;
// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;
bool is_bf16;
bool is_causal;
int num_splits; // How many SMs per attention matrix.
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FMHA_dgrad_params : public FMHA_fprop_params {
// The dQKV matrices.
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q dimension
// void *__restrict__ dk_accum_ptr;
// void *__restrict__ dv_accum_ptr;
// The stride between rows of the dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
uint32_t dq_row_stride_in_elts;
uint32_t dk_row_stride_in_elts;
uint32_t dv_row_stride_in_elts;
uint32_t dq_head_stride_in_elts;
uint32_t dk_head_stride_in_elts;
uint32_t dv_head_stride_in_elts;
// The dO matrix. We assume it is contiguous.
void * __restrict__ do_ptr;
// The pointer to the softmax d sum.
void * __restrict__ dsoftmax_sum;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_params>
struct Launch_params{
Launch_params(cudaDeviceProp * props_,
cudaStream_t stream_,
bool is_dropout_,
bool return_softmax_)
: elts_per_thread(0)
, props(props_)
, stream(stream_)
, is_dropout(is_dropout_)
, return_softmax(return_softmax_) {
}
size_t elts_per_thread;
cudaDeviceProp * props;
cudaStream_t stream;
bool is_dropout;
bool return_softmax;
Kernel_params params;
int num_full_heads;
int num_main_groups;
int heads_last_wave;
int main_steps;
int rest_steps;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params);
void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params);
void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params);
void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha/utils.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/layout/layout.h"
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
struct Fragment_base_ {
// The data type.
using Data_type = Data_type_;
// default input type
using Input_type_ = Data_type_;
// Whether it stores the array of elements.
static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8;
// The number of elements.
static constexpr int NUM_ELTS = NUM_ELTS_;
// The size of an element in bits.
static constexpr int BITS_PER_ELT = BITS_PER_ELT_;
// The size in bytes of a single register.
static constexpr int BYTES_PER_REG = 4;
// The size in bits.
static constexpr int BITS_PER_REG = BYTES_PER_REG * 8;
// The number of registers needed to store the fragment.
static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG);
// The size in bytes (as returned by sizeof(Fragment_base<>)).
static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG;
// The alignment.
static constexpr int ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16);
};
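// Example: Fragment<uint16_t, 8> (the A / B operand fragments below) has BITS_PER_ELT = 16, so
// NUM_REGS = DivUp(8 * 16, 32) = 4, SIZE_IN_BYTES = 16 and ALIGNMENT = 16; the accumulator
// Fragment<float, 8> uses 8 registers (32 bytes), still aligned to 16 bytes.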
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The type of the elements.
typename Data_type_,
// The number of elements.
int NUM_ELTS_,
// The alignment if you want to force a value -- use 0 otherwise.
int ALIGNMENT_ = 0,
// The base class.
typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
// The size of a load/store.
static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t);
// Clear the fragment. Using PTX in that code seems to produce better SASS...
inline __device__ void clear() {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
}
}
// Immutable access to a register.
inline __device__ const uint32_t& reg(int ii) const {
return this->regs_[ii];
}
// Mutable access to a register.
inline __device__ uint32_t& reg(int ii) {
return this->regs_[ii];
}
uint32_t regs_[Base_::NUM_REGS];
// Immutable access to the elements.
inline __device__ const Data_type_& elt(int ii) const {
return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
}
// Mutable access to the elements.
inline __device__ Data_type_& elt(int ii) {
return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
}
// Immutable access to the elements with a cast.
template< typename Cast_type >
inline __device__ const Cast_type& elt_as(int ii) const {
return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
}
// Mutable access to the elements with a cast.
template< typename Cast_type >
inline __device__ Cast_type& elt_as(int ii) {
return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
}
// Add another fragment.
inline __device__ void add(const Fragment &other) {
// TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS?
// Also are we doing int addition or __half2 addition?
#pragma unroll
for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
this->elt(ii) += other.elt(ii);
}
}
// Multiply by another fragment.
inline __device__ void hmul(const Fragment &other) {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
}
}
template <typename elem_type>
inline __device__ void hrelu_() {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
this->reg(ii) = fmha::hrelu2<elem_type>(this->reg(ii));
}
}
};
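// Illustrative sizing check (added sketch, not part of the original kernel code): a fragment
// of eight 16-bit elements packs into four 32-bit registers, i.e. 16 bytes, which is the
// layout the operand fragments below rely on.
static_assert(Fragment<uint16_t, 8>::NUM_REGS == 4, "8 x 16-bit elements fit in 4 registers");
static_assert(Fragment<uint16_t, 8>::SIZE_IN_BYTES == 16, "");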
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Layout >
struct Fragment_a : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Layout >
struct Fragment_b : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Fragment_accumulator : public Fragment<float, 8> {
// The base class.
using Base = Fragment<float, 8>;
// Add two fragments.
template< typename Other_fragment_ >
inline __device__ void add(const Other_fragment_ &other) {
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
this->elt(ii) = this->elt(ii) + other.elt(ii);
}
}
inline __device__ void mul_(const float other) {
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
this->elt(ii) *= other;
}
}
// Do the HMMA.
template< typename Layout_a, typename Layout_b >
inline __device__ void mma(const Fragment_a<Layout_a> &a,
const Fragment_b<Layout_b> &b) {
asm volatile( \
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
" {%0, %1, %2, %3}, \n" \
" {%4, %5, %6, %7}, \n" \
" {%8, %9}, \n" \
" {%0, %1, %2, %3}; \n" \
: "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3))
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
, "r"(b.reg(0)), "r"(b.reg(1)));
asm volatile( \
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
" {%0, %1, %2, %3}, \n" \
" {%4, %5, %6, %7}, \n" \
" {%8, %9}, \n" \
" {%0, %1, %2, %3}; \n" \
: "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7))
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
, "r"(b.reg(2)), "r"(b.reg(3)));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Fragment, int M, int N >
inline __device__ void clear(Fragment (&frag)[M][N]) {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < N; ++ni ) {
frag[mi][ni].clear();
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Accumulator_type, int WARPS_K >
struct Clear_accumulator {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int WARPS_K >
struct Clear_accumulator<float, WARPS_K> {
template< typename Acc, int M, int N >
static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
fmha::clear(acc);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < N; ++ni ) {
acc[mi][ni].mma(a[mi], b[ni]);
}
}
}
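// Usage sketch (hypothetical names, not code from this file): the accumulators are cleared
// once, then gemm() is invoked for every K step with the A/B fragments loaded from shared
// memory for that step, e.g.
//   Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
//   fmha::clear(acc_p);
//   for (int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki) {
//       Fragment_a<Row> frag_q[Mma_tile_p::MMAS_M];
//       Fragment_b<Col> frag_k[Mma_tile_p::MMAS_N];
//       smem_q.load(frag_q, ki);
//       smem_k.load(frag_k, ki);
//       fmha::gemm(acc_p, frag_q, frag_k);
//   }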
////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps half types => cutlass data types
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Type_>
struct HalfTypeToCutlassType { using Type = Type_; };
/// Statically maps __half => cutlass::half_t
template <> struct HalfTypeToCutlassType<__half> {
using Type = cutlass::half_t;
};
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
template <> struct HalfTypeToCutlassType<__nv_bfloat16> {
using Type = cutlass::bfloat16_t;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename elem_type, typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
#else
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
// TD [2022-06-02] We don't support Volta (SM70) yet.
assert(0);
#endif
using Element = typename HalfTypeToCutlassType<elem_type>::Type;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;
constexpr int kIters = Shape::kK / InstructionShape::kK;
// using FragmentA = typename WarpMma::FragmentA;
// using FragmentB = typename WarpMma::FragmentB;
using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA;
using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB;
using FragmentC = typename WarpMma::FragmentC;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) {
// printf("FragmentA::kStorageElements = %d\n", FragmentA::kStorageElements);
// printf("Archmma::FragmentA::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentA::kStorageElements);
// printf("FragmentB::kStorageElements = %d\n", FragmentB::kStorageElements);
// printf("Archmma::FragmentB::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentB::kStorageElements);
// printf("FragmentC::kStorageElements = %d\n", FragmentC::kStorageElements);
// printf("Archmma::FragmentC::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentC::kStorageElements);
// }
// static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
// static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS);
static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS);
static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
// const FragmentA a_cl = reinterpret_cast<const FragmentA (&)>(a);
// const FragmentB b_cl = reinterpret_cast<const FragmentB (&)>(b);
FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);
FragmentA a_cl[kIters][M];
FragmentA b_cl[kIters][N];
constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2;
#pragma unroll
for (int iter = 0; iter < kIters; iter++) {
#pragma unroll
for (int mi = 0; mi < M; mi++) {
uint32_t *a_ptr = a_cl[iter][mi].raw_data();
#pragma unroll
for (int ki = 0; ki < kRegs; ki++) {
a_ptr[ki] = a[mi].regs_[iter * kRegs + ki];
}
}
}
#pragma unroll
for (int iter = 0; iter < kIters; iter++) {
#pragma unroll
for (int ni = 0; ni < N; ni++) {
uint32_t *b_ptr = b_cl[iter][ni].raw_data();
#pragma unroll
for (int ki = 0; ki < kRegs; ki++) {
// b_ptr[ki] = b[ni].regs_[iter * kRegs + ki];
// TD [2022-06-02] For some reason the order for frag_b is different.
b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter];
}
}
}
WarpMma mma_op;
// mma_op(c_cl, a_cl, b_cl, c_cl);
#pragma unroll
for (int iter = 0; iter < kIters; iter++) {
mma_op(c_cl, reinterpret_cast<const typename WarpMma::FragmentA (&)>(a_cl[iter]),
reinterpret_cast<const typename WarpMma::FragmentB (&)>(b_cl[iter]), c_cl);
}
// For some reason the modified c_cl is not reflected back into acc through the reinterpret_cast alias, so copy it back explicitly.
#pragma unroll
for (int mi = 0; mi < M; mi++) {
#pragma unroll
for (int ni = 0; ni < N; ni++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i];
}
}
}
}
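// gemm_cl is an alternative to fmha::gemm above that routes the same register fragments
// through a CUTLASS warp-level MMA; a call site would look like (sketch, assumed names):
//   fmha::gemm_cl<__half>(acc_p, frag_q, frag_k);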
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The number of rows in the CTA tile.
int M_,
// The number of cols in the CTA tile.
int N_,
// The number of elements in the K dimension of the GEMM loop.
int K_,
// The number of rows of warps.
int WARPS_M_,
// The number of cols of warps.
int WARPS_N_,
// The number of warps in the K dimension of the GEMM loop.
int WARPS_K_>
struct Cta_tile_ {
static constexpr int M = M_, N = N_, K = K_;
// The number of warps.
static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
// The number of warps per CTA.
static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
// The number of threads per warp.
static constexpr int THREADS_PER_WARP = 32;
// The number of threads per CTA.
static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Hmma_tile {
// The number of elements computed with a single warp-MMA.
static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;
// The number of elements computed with a single CTA-MMA.
static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;
// The number of MMAs needed to compute the GEMM.
static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);
// // The number of elements computed per warp.
// static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA,
// N_PER_WARP = MMAS_N * N_PER_MMA,
// K_PER_WARP = MMAS_K * K_PER_MMA;
};
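// Worked example (added for illustration; not referenced by the kernels): with a
// Cta_tile_<128, 128, 64, 4, 1, 1> tile, the 4 warps along M cover 64 rows per MMA,
// so the tile needs 2 x 8 MMAs per K step and 4 K steps.
static_assert(Hmma_tile<Cta_tile_<128, 128, 64, 4, 1, 1>>::MMAS_M == 2, "");
static_assert(Hmma_tile<Cta_tile_<128, 128, 64, 4, 1, 1>>::MMAS_N == 8, "");
static_assert(Hmma_tile<Cta_tile_<128, 128, 64, 4, 1, 1>>::MMAS_K == 4, "");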
////////////////////////////////////////////////////////////////////////////////////////////////////
using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;
constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile_>
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
Cta_tile_::N,
Next_power_of_two<Cta_tile_::K>::VALUE,
Cta_tile_::WARPS_M,
Cta_tile_::WARPS_N,
Cta_tile_::WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <fmha/utils.h>
namespace fmha {
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile_,
// The number of bits per element.
int BITS_PER_ELEMENT,
// The number of rows of Q, K or V loaded by this tile.
int ROWS_,
// The number of columns.
int COLS,
int BYTES_PER_LDGS_ = 16
>
struct Gmem_tile_qkv {
using Cta_tile = Cta_tile_;
static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8;
// The size of each LDG.
static constexpr int BYTES_PER_LDG = BYTES_PER_LDGS_;
// The size of a row in bytes.
static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8;
// The number of threads to load a "row" of the matrix.
static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG;
static constexpr int ROWS = ROWS_;
// The number of "rows" loaded per LDG.
static constexpr int ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
// The number of LDGs needed to load a chunk of the Q, K or V matrix.
static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG);
// Ctor.
template< typename BInfo >
inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
const uint32_t head_stride_in_elts, const int headdim,
const BInfo &binfo, const int tidx, bool use_seqlen_q)
: row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
, actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k)
, ptr(reinterpret_cast<char *>(ptr_))
, tidx_(tidx)
, col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_LDG / BYTES_PER_ELEMENT) < headdim) {
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % THREADS_PER_ROW;
// Store the row as we need it to disable the loads.
// TD [2022-04-16]: To minimize registers, we'll recompute row_ instead of storing it
// row_ = row;
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
// int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes);
// Add the block index.
// row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
// Assemble the final pointer.
ptr += row_offset + col * BYTES_PER_LDG;
}
// Store data to shared memory.
template< typename Smem_tile >
inline __device__ void commit(Smem_tile &smem_tile) {
smem_tile.store(fetch_);
}
inline __device__ void load() {
int row_ = tidx_ / THREADS_PER_ROW;
const void *ptrs[LDGS];
uint32_t preds[LDGS];
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
// ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
preds[ii] = col_predicate && ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
fetch_[ii] = make_uint4(0, 0, 0, 0);
}
// not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
fct.load(ii, preds[ii]);
}
}
// Store data to memory.
inline __device__ void store(const uint4 (&data)[LDGS]) {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
// char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
fmha::stg(ptr_, data[ii]);
}
}
}
inline __device__ void move(const int steps = 1) {
// ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
actual_seqlen -= ROWS * steps;
}
// The stride between rows of the QKV matrices.
// int64_t row_stride_in_bytes;
const uint32_t row_stride_in_bytes;
// The pointer.
char *ptr;
// The fetch registers.
uint4 fetch_[LDGS];
// Keep track of the row the thread is processing as we move the tile.
// int row_;
const int tidx_;
// The length of the sequence loaded by that memory tile.
int actual_seqlen;
const bool col_predicate;
};
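// Sizing example (illustrative sketch, not part of the original header): for a 128 x 64
// fp16 tile, BYTES_PER_ROW = 128, so THREADS_PER_ROW = 8; a 128-thread CTA then covers
// ROWS_PER_LDG = 16 rows per load and needs LDGS = 8 predicated 16-byte loads per thread.
// With a Cta_tile providing 128 threads this corresponds to:
//   static_assert(Gmem_tile_qkv<Cta_tile, /*BITS=*/16, /*ROWS=*/128, /*COLS=*/64>::LDGS == 8, "");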
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Cta_tile,
int BYTES_PER_ELEMENT = 2
>
struct Gmem_tile_o {
static_assert(BYTES_PER_ELEMENT == 2 || BYTES_PER_ELEMENT == 4);
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The size of each element.
// static constexpr int BYTES_PER_ELEMENT = 2;
// The size of each STG.
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 4;
static constexpr int COLS = Cta_tile::N;
// The size of a row in bytes.
static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT;
// The number of threads to store a "row" of the matrix.
static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG;
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
static constexpr int ROWS = Cta_tile::M;
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA;
// The number of outer loops for the stores.
static constexpr int LOOPS = ROWS / ROWS_PER_LOOP;
// The number of "rows" stored per STG.
static constexpr int ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
// Do we have to guard against partial writes/reads.
static constexpr bool HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0;
// The number of STGs needed to store a chunk of the O matrix.
static constexpr int STGS_PER_LOOP = DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_STG);
// The number of STGs needed to store a chunk of the O matrix in total.
static constexpr int STGS = STGS_PER_LOOP * LOOPS;
// Ctor.
template<typename BInfo>
// inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx)
inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
const uint32_t head_stride_in_elts, const int headdim,
const BInfo &binfo, const int tidx)
: row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
, actual_seqlen_q(binfo.actual_seqlen_q)
, ptr_(reinterpret_cast<char *>(ptr))
, tidx_(tidx)
, col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_STG / BYTES_PER_ELEMENT) < headdim) {
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % THREADS_PER_ROW;
// Store the row as we need it to disable loads.
// row_ = row;
// The row offset in the batched GEMM.
// int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
uint32_t row_offset = (uint32_t)((binfo.sum_s_q + row) * row_stride_in_bytes);
row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
// Assemble the final pointer.
ptr_ += row_offset + col * BYTES_PER_STG;
// Is that thread active on the last STG?
if( HAS_INCOMPLETE_STG ) {
is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
}
}
// Store data to global memory.
template<typename elem_type=__half>
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
break;
}
if (BYTES_PER_ELEMENT == 4) {
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, src[ii]);
}
} else if (BYTES_PER_ELEMENT == 2) {
float x = reinterpret_cast<const float &>(src[ii].x);
float y = reinterpret_cast<const float &>(src[ii].y);
float z = reinterpret_cast<const float &>(src[ii].z);
float w = reinterpret_cast<const float &>(src[ii].w);
uint2 out = fmha::float4_pack<elem_type>(x, y, z, w);
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out);
}
}
}
}
// Store data to global memory with atomicAdd.
inline __device__ void atomic_add(const uint4 (&src)[STGS_PER_LOOP], int mi) {
static_assert(BYTES_PER_ELEMENT == 4); // Only do atomic add on floats
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
break;
}
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
float *ptr_ = reinterpret_cast<float *>(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
#pragma unroll
for (int kk = 0; kk < 4; ++kk) {
atomicAdd(ptr_ + kk, reinterpret_cast<const float(&)[4]>(src[ii])[kk]);
}
}
}
}
// Load data from global memory.
inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
static_assert(BYTES_PER_ELEMENT == 4);
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
break;
}
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
}
}
}
inline __device__ void move(const int steps = 1) {
// row_ += ROWS * steps;
// ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps;
ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps;
actual_seqlen_q -= ROWS * steps;
}
// The stride between rows of the O matrix.
// int64_t row_stride_in_bytes;
const uint32_t row_stride_in_bytes;
// The pointer.
char *ptr_;
// Is the thread active for the last STG?
int is_active_for_last_stg_;
// The length of the sequence loaded by that memory tile.
int actual_seqlen_q;
const int tidx_;
const bool col_predicate;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, int BYTES_PER_ELEMENT >
struct Gmem_tile_mma_sd {
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// Each STG stores 8 elements.
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8;
// The number of MMAs in the M dimension.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
// The number of MMAs in the N dimension.
static constexpr int MMAS_N = Mma_tile::MMAS_N;
// The number of rows computed per MMA per thread block.
static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
// The number of cols computed per MMA per thread block.
static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA;
// The number of threads per block.
static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA;
// The size of each row in bytes, i.e. the bytes written by the whole CTA with one STG per thread.
static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG;
// The distance between elements stored per loop (in bytes).
static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW;
// The type of elements stored per STG.
using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;
// Ctor.
template<typename Params>
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx)
: ptr_(static_cast<char *>(ptr)) {
// The block index.
// size_t bidx = bidb * params.h + bidh;
uint32_t bidx = bidb * params.h + bidh;
// The distance between two blocks (in bytes).
// const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
// Set store location for each thread at the beginning of the loop
ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG;
}
// Store to global memory.
inline __device__ void store(const Type &data, const int mi, const int ni) {
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::stg(ptr_ + offset, data);
}
// Load from global memory.
inline __device__ void load(Type &data, const int mi, const int ni) {
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::ldg(data, ptr_ + offset);
}
// Move to the next tile.
inline __device__ void move(const int steps = 1) {
ptr_ += LOOP_STRIDE_BYTES * steps;
}
// The pointer in global memory.
char *ptr_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
struct Gmem_tile_mma_s : public Base {
// The number of mmas in the vertical dimension.
static constexpr int M = Base::MMAS_M;
// The number of mmas in the horizontal dimension.
static constexpr int N = Base::MMAS_N;
// The type of the vectors stored by each STG.
using Type = typename Base::Type;
// Ctor.
template< typename Params, typename Block_info >
inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info& binfo, const int tidx)
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
}
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 dst;
dst.x = frag[ni][mi].reg(0);
dst.y = frag[ni][mi].reg(2);
dst.z = frag[ni][mi].reg(1);
dst.w = frag[ni][mi].reg(3);
if( mask.any_valid(mi, ni) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Load from global memory.
template<typename Mask>
inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
regs[mi][ni] = make_uint4(0, 0, 0, 0);
if( mask.any_valid(mi, ni) ) {
Base::load(regs[mi][ni], mi, ni);
}
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile
>
struct Gmem_summary_stats {
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
// The size of each element.
static constexpr int BYTES_PER_ELEMENT = 4;
static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT;
static constexpr int ROWS = Cta_tile::M;
// Ctor.
template<typename Params>
inline __device__ Gmem_summary_stats(void *ptr, const Params &params, const int tidx)
: ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {
// The block index for the batch.
const int bidb = blockIdx.x;
// The block index for the head.
const int bidh = blockIdx.y;
// The block index.
// size_t bidx = bidb * params.h + bidh;
uint32_t bidx = bidb * params.h + bidh;
// Extract the position in the warp.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// The distance between two blocks (in bytes).
// size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
// Set store location for each thread at the beginning of the loop
ptr_row_ = ptr_ + bidx * block_stride_bytes;
ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT;
}
// Store data to global memory.
inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) {
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
if ((warp == 0) && (lane % 4 == 0)) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]);
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]);
}
}
}
// Store data to global memory.
inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]);
}
}
// Load from global memory.
inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
}
}
// Load from global memory.
inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps=1) {
char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT;
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
}
}
// Load data from global memory for the given rows.
template <int N>
inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) {
#pragma unroll
for (int ni = 0; ni < N; ++ni) {
fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT);
}
}
// Move the pointer to the next location.
inline __device__ void move() {
ptr_ += ROWS * BYTES_PER_ELEMENT;
ptr_row_ += ROWS * BYTES_PER_ELEMENT;
}
// Move the pointer to the next location.
inline __device__ void move(const int steps) {
ptr_ += ROWS * BYTES_PER_ELEMENT * steps;
ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps;
}
// The pointer.
char *ptr_;
char *ptr_row_;
const int tidx_;
};
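// Layout note with a worked example (added for illustration): with 32-thread warps,
// BYTES_PER_MMA = 8 * 2 * 4 = 64 bytes, i.e. one float per row of a 16-row MMA tile.
// The quad leaders of warp 0 (lane % 4 == 0) each own rows lane / 4 and lane / 4 + 8,
// which is why store() writes data[2 * mi] at offset 0 and data[2 * mi + 1] eight
// elements later.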
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <cuda_fp16.h>
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u, typename elem_type_=__half>
struct FMHA_kernel_traits {
// The CTA description for the 1st GEMM.
using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;
// The CTA description for the 2nd GEMM.
using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;
// Do we use one buffer for K and V.
static constexpr bool SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u;
// Do we keep K in registers.
static constexpr bool K_IN_REGS = (FLAGS & 0x10u) == 0u;
// Do we keep V in registers.
static constexpr bool V_IN_REGS = (FLAGS & 0x100u) == 0u;
// The global memory tile to load Q.
using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
// The shared memory tile to swizzle Q.
// using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The global memory tile to load K.
using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;
// The shared memory tile to swizzle K.
using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;
// The global memory tile to load V.
using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;
// The shared memory tile to swizzle V.
using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;
// The global memory tile to store O.
using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;
// The shared memory tile for O.
using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;
// The global memory tile to load/store S.
using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;
// The shared memory tile to transpose S.
using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;
using Gmem_tile_do = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
// // The global memory tile to store the accumulated dK and dV
// // Hack: we set BYTES_PER_LDGS=32 to emulate the access pattern of dK and dV
// // where there are 16 bits per element and 16 bytes per load. In reality we won't
// // be issuing any loads or stores of size 32 bytes.
// using Gmem_tile_dkv_accum = fmha::Gmem_tile_qkv<Cta_tile_o, 32, S, D, 32>;
// The global memory tile to store the softmax sum.
using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;
// The shared memory tile to store dp sum.
using Smem_dp_sum = fmha::Smem_tile_dp_sum<Gmem_tile_q, 2>;
using elem_type = elem_type_;
// Make sure the numbers of threads per row match.
static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");
// The number of threads.
static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA;
// Make sure the number of threads matches both CTAs.
static_assert(THREADS == Cta_tile_o::THREADS_PER_CTA, "");
// The amount of shared memory needed to load Q and K.
static constexpr int BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE;
// The extra amount of shared memory needed to load V.
static constexpr int BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE;
// The amount of shared memory needed for Q, K and V.
static constexpr int BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V;
// The amount of shared memory needed to load Q and store O.
static constexpr int BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE;
// The amount of shared memory needed for Q, K, V and O.
static constexpr int BYTES_PER_SMEM = fmha::MaxConstexpr(BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO);
// Make sure we have enough shared memory.
static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};
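// Example instantiation (sketch only; the concrete configurations are typically chosen by
// the launch code): a 256-key block, head dimension 64, a 16-row Q step and 1 x 4 warps
// with the default flag 0x08u (K and V share one smem buffer) yields a 128-thread CTA:
//   using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, __half>;
//   static_assert(Kernel_traits::THREADS == 128, "");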
////////////////////////////////////////////////////////////////////////////////////////////////////
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {
template<typename Cta_tile, bool Is_causal=false>
struct Mask {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
template<typename BInfo>
__device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0)
: actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N)
, loop_step_idx(loop_step_idx_) {
const int warp = tidx / Cta_tile::THREADS_PER_WARP;
const int lane = tidx % Cta_tile::THREADS_PER_WARP;
static_assert(Cta_tile::WARPS_K == 1, "");
// find the warp in the Cta tile
const int warp_n = (warp / Cta_tile::WARPS_M);
const int warp_m = (warp % Cta_tile::WARPS_M);
// decompose warp into 8x4 tile
const int quad = lane / 4;
const int tid = (lane % 4) * 2;
row = warp_m * 16 + quad;
col = warp_n * 16 + tid;
}
inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
// ii and jj iterate over the 2x4 fragment
// const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
const int current_row = row_offset + ii * 8;
const bool col_valid = current_col < actual_seqlen_k;
// const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k;
//&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k;
// bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 1)) {
// printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
// }
return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
// return row_valid && col_valid;
}
// BERT mask: if the upper-left element is invalid, none are valid.
inline __device__ bool any_valid(const int mi, const int ni) const {
return is_valid(mi, ni, 0, 0) || is_valid(mi, ni, 1, 0);
}
inline __device__ void load(const int it) {
row_offset = it * Cta_tile::M + row;
}
int row_offset;
int row;
int col;
const int loop_step_idx;
const int actual_seqlen_k;
};
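// Worked example of the causal predicate (illustrative numbers only): with
// Cta_tile::N = 256 and loop_step_idx = 1, an element at current_row = 40 and
// current_col = 10 checks 10 + 1 * 256 <= 40, which fails, so that score is masked;
// without Is_causal only current_col < actual_seqlen_k is required.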
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "utils.h"
#include <fmha/utils.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The description of the tile computed by this CTA.
typename Cta_tile,
// The number of rows in the 2D shared memory buffer.
int M_,
// The number of cols.
int N_,
// The size in bits of each element.
int BITS_PER_ELEMENT_,
// The number of bytes per STS.
int BYTES_PER_STS_ = 16,
// The number of buffers. (Used in multistage and double buffer cases.)
int BUFFERS_PER_TILE_ = 1,
// Do we enable the fast path for LDS.128 and friends.
int ENABLE_LDS_FAST_PATH_ = 0,
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
int ROWS_PER_XOR_PATTERN_ = 8,
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
int COLS_PER_XOR_PATTERN_ = 1,
// Whether or not to use predicates.
bool USE_PREDICATES_ = true
>
struct Smem_tile_without_skews {
// The size in bits of each element.
enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };
// The size in bytes of a single STS.
enum { BYTES_PER_STS = BYTES_PER_STS_ };
// The number of elements per STS.
enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };
// To support arbitrary N, we pad some values to a power-of-2.
enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE };
// The number of bytes per row without packing of rows.
enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
// The number of bytes per row -- we want at least 128B per row.
enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
// The number of rows in shared memory (two rows may be packed into a single one).
enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };
// The number of threads per row.
enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS };
// The number of threads per row.
enum { THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE };
// The number of STS per row.
enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS };
// It must be at least one.
static_assert(STS_PER_ROW >= 1, "");
// The number of rows written with a single STS.
enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
// Make sure we write to at least one row per STS. Thanks Dr. Obvious ;)
static_assert(ROWS_PER_STS >= 1, "");
// The number of STS needed to store all rows.
enum { STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE };
// The number of STS in total.
enum { STS = STS_PER_COL * STS_PER_ROW };
// TD [2022-06-02] In the case of Q (16 x 64) in the backward pass with 256 threads,
// we only need to store 16 * 64 * 2 = 2KB instead of 4KB.
static constexpr bool PARTIAL_STORE = ROWS_PER_STS > ROWS;
static constexpr int STORING_THREADS = PARTIAL_STORE ? ROWS * THREADS_PER_ROW : Cta_tile::THREADS_PER_CTA;
// The size of one buffer in bytes in shared memory.
// enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };
enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * STORING_THREADS };
// The number of buffers.
enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };
// The size in bytes of total buffers.
enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE };
// The boundary for smem_read_offset and smem_write_offset increment.
enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER };
// Do we enable the LDS.128 fast path?
enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ };
static_assert(ENABLE_LDS_FAST_PATH == 0);
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ };
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS };
// Whether or not to use predicates.
enum { USE_PREDICATES = USE_PREDICATES_ };
// The type of elements that are stored in shared memory by each thread.
using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;
// Ctor.
inline __device__ Smem_tile_without_skews(void *smem, int tidx)
: smem_(__nvvm_get_smem_pointer(smem)), tidx_(tidx) {
// The row written by a thread. See doc/mma_smem_layout.xlsx.
int smem_write_row = tidx / THREADS_PER_ROW;
// The XOR pattern.
int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;
// Compute the column and apply the XOR pattern.
int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;
// The offset.
this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
// TODO: Why not merge it with the read offset?
// this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
// this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
}
// Compute the store pointers.
template< int N >
inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
// Decompose the STS into row/col.
int row = ii / STS_PER_ROW;
int col = ii % STS_PER_ROW;
// Assemble the offset.
int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW;
// Take the column into account.
if( STS_PER_ROW > 1 ) {
offset += col*THREADS_PER_ROW*BYTES_PER_STS;
}
// Apply the XOR pattern if needed.
if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {
const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
}
// Assemble the final pointer :)
// ptrs[ii] = smem_ + offset + smem_write_buffer_;
// smem_write_buffer_ is already merged with smem_write_offset_
ptrs[ii] = smem_ + offset;
}
}
inline __device__ void debug_reset() {
for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
for( int row = 0; row < ROWS; ++row ) {
for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
if( threadIdx.x == 0 ) {
uint32_t val = 0x0;
sts(val, smem_ + row*BYTES_PER_ROW + col + buffer);
}
}
}
}
}
// Print the content of the tile (only for debug ;)).
inline __device__ void debug_print() const {
for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
for( int row = 0; row < ROWS; ++row ) {
for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
if( threadIdx.x == 0 ) {
uint32_t val;
lds(val, smem_ + row*BYTES_PER_ROW + col + buffer);
printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n",
blockIdx.x,
blockIdx.y,
blockIdx.z,
smem_,
buffer,
row,
col,
val);
}
}
}
}
}
// Move the read offset to next buffer.
inline __device__ void move_to_next_read_buffer() {
// if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
// this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
// } else if( BUFFERS_PER_TILE > 1 ) {
// this->smem_read_buffer_ += BYTES_PER_BUFFER;
// }
if( BUFFERS_PER_TILE > 1 && smem_read_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
this->smem_read_offset_ -= BYTES_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_read_offset_ += BYTES_PER_BUFFER;
}
}
// Move the read offset to next buffer. TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer() {
this->move_to_next_read_buffer();
}
// Move the read offset to next N buffer (circular-buffer).
inline __device__ void move_to_next_read_buffer(int N) {
if( BUFFERS_PER_TILE > 1 ) {
// this->smem_read_buffer_ += N * BYTES_PER_BUFFER;
// this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
this->smem_read_offset_ += N * BYTES_PER_BUFFER;
this->smem_read_offset_ -= smem_read_offset_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
}
}
// Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer(int N) {
this->move_to_next_read_buffer(N);
}
// Move the write offset to next buffer.
inline __device__ void move_to_next_write_buffer() {
// if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
// this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
// } else if( BUFFERS_PER_TILE > 1 ) {
// this->smem_write_buffer_ += BYTES_PER_BUFFER;
// }
if( BUFFERS_PER_TILE > 1 && smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
this->smem_write_offset_ -= BYTES_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_write_offset_ += BYTES_PER_BUFFER;
}
}
// Move the write offset to next buffer. TODO: Remove that member function!
inline __device__ void move_next_write_buffer() {
this->move_to_next_write_buffer();
}
// Move the read offset.
inline __device__ void move_read_offset(int delta) {
this->smem_read_offset_ += delta;
}
// Move the write offset.
inline __device__ void move_write_offset(int delta) {
this->smem_write_offset_ += delta;
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
// Trying to reduce the shared mem for Q from 4KB per buffer to 2KB per buffer.
if (!PARTIAL_STORE || (tidx_ / THREADS_PER_ROW < ROWS)) {
sts(smem_ptrs, data);
}
}
// Store to the tile in shared memory.
template< int N, int M >
inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) {
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
sts(smem_ptrs, data, preds);
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) {
this->store(data, preds);
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) {
uint32_t tmp[1] = { preds };
this->store(gmem_ptrs, tmp);
}
// The shared memory pointer.
const uint32_t smem_;
// The read offset. Reserve 4 offsets if needed.
int smem_read_offset_;
// The write offset.
int smem_write_offset_;
// The buffer base offset for read.
// int smem_read_buffer_;
// The buffer base offset for write.
// int smem_write_buffer_;
const int tidx_;
};
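// Worked example of the partial-store case flagged above (illustrative only): a 16 x 64
// fp16 Q tile has BYTES_PER_ROW = 128 and THREADS_PER_ROW = 8, so a 256-thread CTA would
// write ROWS_PER_STS = 32 rows per STS while only ROWS = 16 exist. PARTIAL_STORE is then
// true, STORING_THREADS = 128 and BYTES_PER_BUFFER = 2KB instead of 4KB.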
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
// Whether or not to use predicates.
bool USE_PREDICATES = true
>
struct Smem_tile_a {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K, int MMAS_K_WITH_PADDING >
struct Compute_reset_mask {
// The potential mask.
enum { HALF = MMAS_K_WITH_PADDING / 2 };
// The remainder.
enum { MOD = MMAS_K % HALF };
// The final value.
enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask<MOD, HALF>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K_WITH_PADDING >
struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> {
enum { VALUE = 0 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K >
struct Compute_reset_mask<MMAS_K, MMAS_K> {
enum { VALUE = MMAS_K - 1 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_a {
// The size in bits.
enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };
// The number of rows.
enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a<Cta_tile::K>::VALUE
>
struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
Cta_tile::M,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_A,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::M,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_A,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1>;
// The fragment.
using Fragment = Fragment_a<Row>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// Ctor.
inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) {
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(WARPS_M == 1);
static_assert(WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_K == 1);
static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);
// The row and column read by the thread.
int smem_read_row = (tidx & 0x0f);
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
smem_read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {
#pragma unroll
for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) {
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
// ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset);
// Store the value into the fragment.
a[mi].reg(0) = tmp.x;
a[mi].reg(1) = tmp.y;
a[mi].reg(2) = tmp.z;
a[mi].reg(3) = tmp.w;
}
        // Move the offset to the next position. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset() {
// The number of MMAs in the K dimension.
enum { MMAS_K = Mma_tile::MMAS_K };
// The number of MMAs in the K dimension when we include padding.
enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
// Assemble the mask.
enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
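// Worked example of the swizzle in load() above, assuming Mma_tile_with_padding::MMAS_K == 4:
// with BYTES_PER_LDS == 16 the trailing XORs for ki = 0, 1, 2, 3 are 32, 96, 32, 96 bytes, so
// successive loads read from byte offsets +0, +32, +64, +96 relative to the starting offset and
// then wrap back, walking each swizzled LDSM column exactly once per pass over K.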
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_row_a<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
    // Whether or not to use predicates.
bool USE_PREDICATES = true
>
struct Smem_tile_b {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_b {
// The size in bits.
enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };
// The number of rows.
enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b<Cta_tile::K>::VALUE
>
struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
Cta_tile::N,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::N,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1>;
// The fragment.
using Fragment = Fragment_b< Col>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// The number of STS per thread
enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
// The number of STS per thread must be at least 1.
enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };
// Ctor.
inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) {
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);
static_assert(WARPS_M == 1);
static_assert(WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_K == 1);
// The masks to select the warps.
const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
// The divisor for the warps.
const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
// The row and column read by the thread.
int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA +
(tidx & 0x07) +
(tidx & 0x10) / 2;
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
smem_read_col ^= (tidx & 0x08) / 8;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
// ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
ldsm(tmp, this->smem_ + this->smem_read_offset_ + offset);
// Store the value into the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
}
        // Move the offset to the next position. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset() {
// The number of MMAs in the K dimension.
enum { MMAS_K = Mma_tile::MMAS_K };
// The number of MMAs in the K dimension when we include padding.
enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
// Assemble the mask.
enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_b< Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE >
: public Smem_tile_col_b<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_col_b< Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b< N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b<Cta_tile::N>::VALUE,
// How many cols to use for the XOR pattern to avoid bank conflicts?
int COLS_PER_XOR_PATTERN_ = 1
>
struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
Cta_tile::K,
Cta_tile::N,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
COLS_PER_XOR_PATTERN_> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::K,
Cta_tile::N,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
COLS_PER_XOR_PATTERN_>;
// The fragment.
using Fragment = Fragment_b<Row>;
    // Can we use LDSM? Not if the data type is 32 bits wide.
enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };
// The number of elements per LDS.
enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };
// The number of STS per thread
enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
// The number of STS per thread must be at least 1.
enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };
// Ctor.
inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) {
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(WARPS_K == 1);
static_assert(WARPS_M == 4 || WARPS_M == 8);
static_assert(WARPS_N == 1);
// The masks to select the warps.
const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
const int WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
static_assert(USE_LDSMT);
static_assert(Base::ROWS_PER_XOR_PATTERN == 2 || Base::ROWS_PER_XOR_PATTERN == 4 || Base::ROWS_PER_XOR_PATTERN == 8);
// The row/col read by the thread.
int smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 +
(tidx & 0x07) + (tidx & 0x08);
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
int smem_read_col = ((smem_read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW_BEFORE_PACKING + smem_read_col*BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// The size of each element in bits.
const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( BYTES_PER_MMA_PER_CTA >= 128 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
} else if( BYTES_PER_MMA_PER_CTA == 64 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
Mma_tile::MMAS_N % 2 == 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
// The size of each element in bits.
const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
// uint32_t smem_read_og = this->smem_ + this->smem_read_offset_;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Prepare the offset.
int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW_BEFORE_PACKING;
if ( BYTES_PER_MMA_PER_CTA == 32 ) {
offset += this->smem_read_offset_;
} else if ( BYTES_PER_MMA_PER_CTA == 64 ) {
offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2;
} else {
offset += this->smem_read_offset_ + (ni ) * BYTES_PER_MMA_PER_CTA;
}
// Load the data using LDSM.MT88.2.
// uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;
uint32_t ptr = this->smem_ + offset;
uint4 tmp;
if( USE_LDSMT ) {
ldsmt(tmp, ptr);
} else {
lds(tmp.x, (ptr ) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING);
lds(tmp.y, (ptr ) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING);
lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW_BEFORE_PACKING);
lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW_BEFORE_PACKING);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("BYTES_PER_MMA_PER_CTA=%d, ni = %d, smem_read diff = %d\n", BYTES_PER_MMA_PER_CTA, ni, ptr - smem_read_og);
// }
// Store those values in the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if( BYTES_PER_MMA_PER_CTA >= 128 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
} else if( BYTES_PER_MMA_PER_CTA == 64 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 8 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
Mma_tile::MMAS_N % 2 == 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
};
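// Worked example of the ni swizzle in load() above, assuming 16-bit elements and
// N_PER_MMA_PER_CTA == 16 (so BYTES_PER_MMA_PER_CTA == 32) with MMAS_N == 4: the trailing XORs
// alternate 2 * 16 and 6 * 16 bytes, so consecutive ni read from byte offsets +0, +32, +64, +96
// and the read offset returns to its starting value after the loop.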
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_row_b<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, Rows_per_xor_pattern_col_b<Cta_tile::N>::VALUE, 1> {
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, Rows_per_xor_pattern_col_b<Cta_tile::N>::VALUE, 1>;
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The fragment.
using Fragment = Fragment_b< fmha::Col>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// Ctor.
inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) {
// The row/col read by the thread.
int read_row, read_col;
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);
constexpr int ROWS_PER_PACKING = Base::BYTES_PER_ROW / Base::BYTES_PER_ROW_BEFORE_PACKING;
read_col = ((read_row / ROWS_PER_PACKING) % Base::ROWS_PER_XOR_PATTERN) * Base::COLS_PER_XOR_PATTERN;
read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW_BEFORE_PACKING + read_col * BYTES_PER_LDS;
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Jump by 16 * #warps row.
int row = ki * 16 * Cta_tile::WARPS_K;
// Load the data using LDSM.MT88.2.
uint4 tmp;
fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW_BEFORE_PACKING);
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if( Mma_tile::MMAS_N == 1 ) {
// noop
} else if( Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
} else if( Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if (Mma_tile::MMAS_N == 8) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 4 == 3 ? 14 : (ni % 2 == 1 ? 6 : 2));
} else {
assert(false); // Not implemented!
}
}
}
};
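// Example of the read mapping above: each K-warp owns a 16-row slab of the V tile, so for
// tidx = 53 (warp 1, lane 21) read_row = (53 & 0xe0) / 2 + (53 & 0x0f) = 16 + 5 = 21, and
// (53 & 0x10) / 16 = 1 flips the low bit of the LDSM column. In load(), ki then advances the
// row by 16 * WARPS_K so every K-warp steps through its own slabs.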
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Smem_tile_o {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The accumulators.
using Accumulator = fmha::Fragment_accumulator;
// The accumulators.
using Data_type = typename Accumulator::Data_type;
// The size of each element.
static constexpr int BYTES_PER_ELEMENT = sizeof(Data_type);
// The size of each STS.
static constexpr int BYTES_PER_STS = 8;
// The size of each row in shared memory.
static constexpr int BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT;
// The size of each LDS.
static constexpr int BYTES_PER_LDS = 16;
static constexpr int THREADS_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT / BYTES_PER_LDS;
// The number of rows.
static constexpr int ROWS = Cta_tile::M;
// The number of "rows" to process per loop iteration (in the "epilogue").
static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA;
// The number of outer loops.
static constexpr int LOOPS = ROWS / ROWS_PER_LOOP;
// Make sure it matches our expectations.
static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");
// The number of rows loaded per LDS.
static constexpr int ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
    // Do we have to guard against partial writes/reads?
static constexpr bool HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0;
// The total number of LDS per loop.
static constexpr int LDS_PER_LOOP = fmha::DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_LDS);
// The amount of shared memory.
static constexpr int BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW;
// The write pointer.
uint32_t smem_write_, smem_read_;
// Is the thread active for the last LDS of the series?
int is_active_for_last_lds_;
// static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K);
static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");
// Ctor.
inline __device__ Smem_tile_o(void *smem, int tidx) {
// Get a 32-bit value for the shared memory address.
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
static_assert(Cta_tile::N == 16 || Cta_tile::N == 32 || Cta_tile::N == 64 || Cta_tile::N == 128);
int write_row = (tidx & 0x1c) / 4;
const int lane = tidx % 32;
const int warp = tidx / 32;
constexpr int ELEMENTS_PER_STS = BYTES_PER_STS / BYTES_PER_ELEMENT;
constexpr int STS_PER_WARP = 16 * Mma_tile::MMAS_N / ELEMENTS_PER_STS;
int write_col = warp * STS_PER_WARP + lane % STS_PER_WARP;
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("write_row = %d, write_col = %d\n", write_row, write_col);
// }
// if ((blockIdx.x == 0) && (blockIdx.y == 0) && (write_row == 0) && (write_col == 0)) {
// printf("threadIdx.x = %d\n", threadIdx.x);
// }
// Assemble the write pointer.
smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
// The element read by each thread.
int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
// Take the XOR pattern into account for the column.
read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : 8)));
// read_col ^= 2 * (read_row % (Cta_tile::N == 16 ? 2 : (Cta_tile::N == 32 ? 4 : (Cta_tile::N == 128 ? 16 : 8))));
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("read_row = %d, read_col = %d\n", read_row, read_col);
// }
// if ((blockIdx.x == 0) && (blockIdx.y == 0) && (read_row == 0) && (read_col == 0)) {
// printf("threadIdx.x = %d\n", threadIdx.x);
// }
// Assemble the read pointer.
this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
// Is that thread active on the last LDS?
if( HAS_INCOMPLETE_LDS ) {
this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M;
}
}
// Load the output fragments.
template <bool zero_init=true>
inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
#pragma unroll
for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {
// Load the elements before the reduction (split-K).
uint4 tmp[Cta_tile::WARPS_K];
#pragma unroll
for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
uint32_t smem_read = this->smem_read_ + imm;
// TD [2022-06-05] Ugly fix for d=128 in the forward pass, maybe there's a better way.
if ((Cta_tile::N == 128) && (ROWS_PER_LDS == 4) && (ii % 2 == 1)) {
smem_read ^= 8 * BYTES_PER_LDS;
}
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("imm diff = %d\n", smem_read - this->smem_read_);
// }
if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
// fmha::lds(tmp[jj], this->smem_read_ + imm);
fmha::lds(tmp[jj], smem_read);
}
}
// Perform the reduction.
out[ii] = zero_init ? tmp[0] : fmha::fadd4(out[ii], tmp[0]);
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("out reduction: out = %.6f\n", reinterpret_cast<float (&)[4]>(out[ii])[0]);
// }
#pragma unroll
for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
out[ii] = fmha::fadd4(out[ii], tmp[jj]);
// if ((threadIdx.x == 8) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("out reduction tmp = %.6f, out = %.6f\n", reinterpret_cast<float (&)[4]>(tmp[jj])[0], reinterpret_cast<float (&)[4]>(out[ii])[0]);
// }
}
}
}
// Store the accumulators.
template <int M, int N>
inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
// uint32_t smem_write_og = this->smem_write_;
static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// The number of MMAs that are stored per loop iteration.
static constexpr int MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS;
// Store 1st column of the different MMAs.
#pragma unroll
for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
// Precompute the immediates to jump between rows.
int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;
uint2 tmp0, tmp1;
tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0);
tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1);
tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2);
tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3);
// Store.
fmha::sts(this->smem_write_ + row_0, tmp0);
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
// }
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// uint4 read_tmp;
// fmha::lds(read_tmp, this->smem_read_);
// printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
// }
            // Swizzle the write pointer using an XOR of 32B.
this->smem_write_ ^= 32;
// Store 2nd column of the different MMAs.
#pragma unroll
for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
// Precompute the immediates to jump between rows.
int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;
uint2 tmp0, tmp1;
tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4);
tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5);
tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6);
tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7);
// Store.
fmha::sts(this->smem_write_ + row_0, tmp0);
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// if ((threadIdx.x == 16) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("smem_write diff = %d\n", this->smem_write_ - smem_write_og);
// }
            // Cancel the previous XOR of 1 + swizzle the write pointer using an XOR of 32B or 64B.
static_assert(Mma_tile::MMAS_N <= 8, "Not implemented");
if( Mma_tile::MMAS_N >= 8 && ni % 4 == 3 ) {
this->smem_write_ ^= 15 * 32;
} else if( Mma_tile::MMAS_N >= 4 && ni % 2 == 1 ) {
this->smem_write_ ^= 7 * 32;
} else if( Mma_tile::MMAS_N >= 2 ) {
this->smem_write_ ^= 3 * 32;
} else {
this->smem_write_ ^= 3 * 32;
}
// this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// uint4 read_tmp;
// fmha::lds(read_tmp, this->smem_read_);
// printf("smem_o = %.6f\n", reinterpret_cast<float (&)[4]>(read_tmp)[0]);
// }
}
}
};
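// Note on the split-K layout above: each row of this tile holds Cta_tile::N floats per K-warp
// (BYTES_PER_ROW = N * WARPS_K * 4), store() writes each warp's partial O into its own N-wide
// slice of the row via write_col, and load() reads the WARPS_K slices of a row and sums them
// with fadd4, accumulating into the caller's registers when zero_init is false.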
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_mma {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
using Fragment = fmha::Fragment_a<fmha::Col>;
enum { COLS = Cta_tile::N };
enum { BYTES_PER_ELT = 2 };
enum { BYTES_PER_STS = 4 };
enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO
enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
enum { WARPS_K = Cta_tile::WARPS_K };
static_assert(WARPS_K == 1);
inline __device__ Smem_tile_mma(char *smem, int tidx) {
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
int write_col, write_row;
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_M == 8) || WARPS_N == 1);
if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
write_row = (tidx & 0x1c) / 4;
write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
write_col ^= (write_row & 0x07) * 4;
} else {
write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
write_col = (tidx & 0x03);
// write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x0f)))) * 4;
write_col ^= (write_row & (BYTES_PER_ROW == 32 ? 0x01 : (BYTES_PER_ROW == 64 ? 0x03 : (BYTES_PER_ROW == 128 ? 0x07 : 0x07)))) * 4;
}
// write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
static_assert(COLS == Cta_tile::N);
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
// fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
// fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
// offset ^= 4 * BYTES_PER_STS;
// fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
// fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
// size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * BYTES_PER_STS;
fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
}
}
}
template<typename Fragment, int M, int N>
inline __device__ void store(const Fragment (&frag)[N][M]) {
static_assert(COLS == Cta_tile::N);
uint4 regs[M][N];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
                // Need to swap reg(1) and reg(2) here since when we load it we transpose again.
regs[mi][ni] = make_uint4(frag[ni][mi].reg(0), frag[ni][mi].reg(2),
frag[ni][mi].reg(1), frag[ni][mi].reg(3));
}
}
this->store(regs);
}
// uint32_t smem_;
// uint32_t write_offset_;
uint32_t smem_write_;
};
template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_transposed : public Base {
enum { BYTES_PER_LDS = 16 };
enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
using Fragment = typename Base::Fragment;
inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
int read_row, read_col;
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
// read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x0f))));
read_col ^= (read_row & 0x07);
// read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
template<int M, int N>
inline __device__ void load(Fragment (&frag)[M][N]) {
static_assert(Base::COLS == Cta_tile::N);
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
// fmha::ldsmt(dst, this->smem_ + offset);
// size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::ldsmt(dst, offset);
frag[mi][ni].reg(0) = dst.x;
frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major!
frag[mi][ni].reg(2) = dst.y;
frag[mi][ni].reg(3) = dst.w;
}
}
}
// uint32_t read_offset_;
uint32_t smem_read_;
};
template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_epilogue : public Base {
enum { BYTES_PER_LDS = 16 };
enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS };
static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW);
enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS };
static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
    static_assert((WARPS_M == 4 || WARPS_M == 8) || WARPS_N == 1);
using Acc = fmha::Fragment_accumulator;
inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
const int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
// read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : 0x07)));
static_assert(Base::BYTES_PER_ROW == 32 || Base::BYTES_PER_ROW == 64 || Base::BYTES_PER_ROW == 128 || Base::BYTES_PER_ROW == 256);
read_col ^= (read_row & (Base::BYTES_PER_ROW == 32 ? 0x01 : (Base::BYTES_PER_ROW == 64 ? 0x03 : (Base::BYTES_PER_ROW == 128 ? 0x07 : 0x07))));
// read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
inline __device__ void load(uint4 (&data)[NUM_LDS]) {
for( int ii = 0; ii < NUM_LDS; ii++ ) {
// size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
// fmha::lds(data[ii], this->smem_ + offset);
// size_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
uint32_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
fmha::lds(data[ii], offset);
}
}
template<typename elem_type=__half, int M, int N>
inline __device__ void store(const Acc (&acc)[M][N]){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// 1st row - 4 elements per row.
float tmp00 = acc[mi][ni].elt(0);
float tmp01 = acc[mi][ni].elt(1);
float tmp02 = acc[mi][ni].elt(4);
float tmp03 = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc[mi][ni].elt(2);
float tmp11 = acc[mi][ni].elt(3);
float tmp12 = acc[mi][ni].elt(6);
float tmp13 = acc[mi][ni].elt(7);
uint32_t x = fmha::float2_pack<elem_type>(tmp00, tmp01);
uint32_t y = fmha::float2_pack<elem_type>(tmp02, tmp03);
uint32_t z = fmha::float2_pack<elem_type>(tmp10, tmp11);
uint32_t w = fmha::float2_pack<elem_type>(tmp12, tmp13);
// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);
// fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);
// offset ^= 4 * Base::BYTES_PER_STS;
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
// fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
// size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_);
// }
fmha::sts(offset + 0 * BYTES_PER_ROW, x);
fmha::sts(offset + 8 * BYTES_PER_ROW, z);
offset ^= 4 * Base::BYTES_PER_STS;
fmha::sts(offset + 0 * BYTES_PER_ROW, y);
fmha::sts(offset + 8 * BYTES_PER_ROW, w);
}
}
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
                uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
                fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
                fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
                offset ^= 4 * Base::BYTES_PER_STS;
                fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
                fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
}
}
}
// uint32_t read_offset_;
uint32_t smem_read_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_transpose {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
using Fragment_write = fmha::Fragment_b<fmha::Col>;
using Fragment_read = fmha::Fragment_b<fmha::Col>;
enum { COLS = Cta_tile::N };
enum { BYTES_PER_ELT = 2 };
enum { BYTES_PER_STS = 4 };
enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO
enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };
enum { BYTES_PER_LDS = 16 };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
enum { WARPS_K = Cta_tile::WARPS_K };
static_assert(WARPS_K == 1);
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
inline __device__ Smem_tile_transpose(char *smem, int tidx) {
smem_ = __nvvm_get_smem_pointer(smem);
// uint32_t smem_ = __nvvm_get_smem_pointer(smem);
int write_col, write_row;
        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_M == 8) || WARPS_N == 1);
if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
write_row = (tidx & 0x1c) / 4;
write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
} else {
write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
write_col = (tidx & 0x03);
}
write_col ^= (write_row & 0x07) * 4;
write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
// smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
int read_row, read_col;
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
read_col ^= (read_row & 0x07);
read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
// smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
template<int M, int N>
inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2));
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3));
}
}
template<int N>
inline __device__ void load(Fragment_read (&frag_r)[N]) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
frag_r[ni].reg(0) = dst.x;
frag_r[ni].reg(1) = dst.y; // Fragment B regs col major!
frag_r[ni].reg(2) = dst.z;
frag_r[ni].reg(3) = dst.w;
}
}
template<int M, int N>
inline __device__ void transpose(const Fragment_write (&frag_w)[M][N], Fragment_read (&frag_r)[M], int mi) {
static_assert(COLS == Cta_tile::N);
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2));
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3));
}
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
            // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
frag_r[ni].reg(0) = dst.x;
frag_r[ni].reg(1) = dst.y; // Fragment B regs col major!
frag_r[ni].reg(2) = dst.z;
frag_r[ni].reg(3) = dst.w;
}
}
uint32_t smem_;
uint32_t write_offset_;
uint32_t read_offset_;
// uint32_t smem_write_;
// uint32_t smem_read_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename Gmem_tile,
// The number of buffers. (Used in multistage and double buffer cases.)
int BUFFERS_PER_TILE_ = 1
>
struct Smem_tile_dp_sum {
using Cta_tile = typename Gmem_tile::Cta_tile;
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The size of each element.
static constexpr int BYTES_PER_ELEMENT = 4;
static constexpr int ROWS = Gmem_tile::ROWS;
static constexpr int THREADS_PER_ROW = Gmem_tile::THREADS_PER_ROW;
static constexpr int MMAS_M = Mma_tile::MMAS_M;
static constexpr int ROWS_PER_LDG = Gmem_tile::ROWS_PER_LDG;
static constexpr int LDGS = Gmem_tile::LDGS;
static constexpr int ROWS_PER_MMA = Mma_tile::M_PER_MMA;
// The size of one buffer in bytes in shared memory.
static constexpr int BYTES_PER_BUFFER = ROWS * BYTES_PER_ELEMENT;
// The number of buffers.
static constexpr int BUFFERS_PER_TILE = BUFFERS_PER_TILE_;
    // The total size in bytes across all buffers.
static constexpr int BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE;
    // The boundary for the smem_read_buffer_ and smem_write_buffer_ increments.
static constexpr int ROWS_PER_TILE_INC_BOUNDARY = ROWS * BUFFERS_PER_TILE - ROWS;
inline __device__ Smem_tile_dp_sum(float *smem, const int tidx)
: smem_(smem), smem_read_buffer_(smem), smem_write_buffer_(smem), tidx_(tidx) {
}
// Move the read offset to next buffer.
inline __device__ void move_to_next_read_buffer() {
if( BUFFERS_PER_TILE > 1 && (smem_read_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) {
this->smem_read_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_read_buffer_ += ROWS;
}
}
// Move the write offset to next buffer.
inline __device__ void move_to_next_write_buffer() {
if( BUFFERS_PER_TILE > 1 && (smem_write_buffer_ - smem_) >= ROWS_PER_TILE_INC_BOUNDARY ) {
this->smem_write_buffer_ -= ROWS_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_write_buffer_ += ROWS;
}
}
inline __device__ void store(const float (&sum)[LDGS]) {
if (tidx_ % THREADS_PER_ROW == 0) {
int row = tidx_ / THREADS_PER_ROW;
#pragma unroll
for (int i = 0; i < LDGS; ++i) {
if (row + i * ROWS_PER_LDG < ROWS) {
smem_write_buffer_[row + i * ROWS_PER_LDG] = sum[i];
}
}
}
}
inline __device__ void store(const float sum, const int buffer_idx) {
float *smem_write = smem_ + buffer_idx * ROWS;
int row = tidx_ / THREADS_PER_ROW;
if ((row < ROWS) && (tidx_ % THREADS_PER_ROW == 0)) {
smem_write[row] = sum;
}
}
inline __device__ void store(const float (&sum)[LDGS], const int buffer_idx) {
float *smem_write = smem_ + buffer_idx * ROWS;
if (tidx_ % THREADS_PER_ROW == 0) {
int row = tidx_ / THREADS_PER_ROW;
#pragma unroll
for (int i = 0; i < LDGS; ++i) {
if (row + i * ROWS_PER_LDG < ROWS) {
smem_write[row + i * ROWS_PER_LDG] = sum[i];
}
}
}
}
inline __device__ void store_pair(const float (&sum)[MMAS_M * 2]) {
float *smem_write = smem_;
// Extract the position in the warp.
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
int row = lane / 4;
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0];
smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1];
}
}
inline __device__ void store_pair(const float (&sum)[MMAS_M * 2], const int buffer_idx) {
float *smem_write = smem_ + buffer_idx * ROWS;
// Extract the position in the warp.
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
int row = lane / 4;
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
smem_write[mi * ROWS_PER_MMA + row + 0] = sum[mi * 2 + 0];
smem_write[mi * ROWS_PER_MMA + row + 8] = sum[mi * 2 + 1];
}
}
template<int N>
inline __device__ void load(float (&sum)[N], const int (&row)[N]) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
sum[ni] = smem_read_buffer_[row[ni]];
}
}
template<int N>
inline __device__ void load(float (&sum)[N], const int (&row)[N], const int buffer_idx) {
float *smem_read = smem_ + buffer_idx * ROWS;
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
sum[ni] = smem_read[row[ni]];
}
}
static inline __device__ float reduce_warp(float sum) {
fmha::SumOp<float> sum_op;
return fmha::Allreduce<THREADS_PER_ROW>::run(sum, sum_op);
}
const int tidx_;
float * const smem_;
float *smem_read_buffer_;
float *smem_write_buffer_;
};
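// Note on the buffering above: ROWS_PER_TILE_INC_BOUNDARY == ROWS * (BUFFERS_PER_TILE - 1), so
// with BUFFERS_PER_TILE == 2 the move_to_next_*_buffer() helpers simply ping-pong the read and
// write pointers between smem_ and smem_ + ROWS; with a single buffer they are no-ops.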
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cmath>
#include <cuda_fp16.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float apply_exp_(float x, float max) {
return __expf(x - max);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float apply_exp2_(float x, float max) {
return exp2f(x - max);
// With fast-math, this produces the same PTX instruction as the assembly below
// float diff = x - max;
// float res;
// asm ("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff));
// return res;
}
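// The exp2 variant relies on exp(x - max) == exp2(x * log2(e) - max * log2(e)): the softmax code
// below pre-multiplies its inputs by M_LOG2E so that the subtraction and the scale fuse into a
// single ffma feeding the ex2 instruction.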
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int COLS> struct ReadType {};
template<> struct ReadType<4> { using T = float;};
template<> struct ReadType<8> { using T = float2;};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile, typename Kernel_traits>
struct Smem_tile_reduce {
    // Helper class to distribute the per-warp row reductions of the MMA tiles across quads.
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
static constexpr int MMAS_N = Mma_tile::MMAS_N;
static constexpr int WARPS_M = Cta_tile::WARPS_M;
static constexpr int WARPS_N = Cta_tile::WARPS_N;
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
static constexpr int COLS = WARPS_N;
static_assert(COLS == 4 || COLS == 8);
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
static constexpr int ELTS_PER_TILE = ROWS * COLS;
static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
// TD [2022-05-02]: No longer true if head_dim != 64
// static_assert(THREADS_PER_GROUP == 16); // DEBUG
static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
static_assert(LOOPS == 1);
using read_t = typename ReadType<COLS>::T;
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
int lane = tidx % 32;
int warp = tidx / 32;
int warp_m = warp % WARPS_M;
int warp_n = warp / WARPS_M;
qid_ = lane % 4;
int qp = lane / 4;
// Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
// This won't affect reading as we assume commutative reduction ops.
const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
smem_read_row_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qid_];
}
__device__ inline void store(float (&frag)[2 * MMAS_M]) {
if( qid_ == 0 ) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * WARPS_N;
smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
}
}
}
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
}
}
__device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4];
}
}
int qid_;
float *smem_write_;
read_t *smem_read_;
read_t *smem_read_row_;
};
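// Note on the layout above: for each MMA, store() has only the quad leader (qid_ == 0) write its
// two row partials (rows qp and qp + 8 of the 16-row block) into a ROWS x WARPS_N float array,
// one XOR-swizzled column per N-warp. load() then reinterprets the tile as read_t so the four
// lanes of a quad each pick up one column (a float for WARPS_N == 4, a float2 for WARPS_N == 8),
// spreading the WARPS_N partials of a row across the quad.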
template<typename Cta_tile, typename Kernel_traits>
struct Softmax_base {
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
static constexpr int MMAS_N = Mma_tile::MMAS_N;
    // The number of groups of warps such that we have at most 4 warps writing consecutive elements.
static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4);
// The number of elements that we are going to store per row.
static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS;
// The number of rows.
static constexpr int ROWS = Cta_tile::M * GROUPS;
// The total number of elements.
static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW;
// Ctor.
template<typename Params>
inline __device__ Softmax_base(const Params &params, void *smem, int tidx)
: // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {
        // Move to the 1st mask loaded by the thread.
// packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);
// Extract the position in the warp.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// Decompose the warp index into M and N.
int warp_m = warp % Cta_tile::WARPS_M;
int warp_n = warp / Cta_tile::WARPS_M;
// Decompose the warp-n index into group/position-inside-the-group.
int warp_g = warp_n / ELEMENTS_PER_ROW;
int warp_i = warp_n % ELEMENTS_PER_ROW;
// The location written by the threads.
int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
int write_col = warp_i;
// Assemble the write pointer.
smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];
// Assemble the read pointer.
smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
}
template<bool zero=false, typename Mask>
inline __device__ void apply_mask(const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ii = 0; ii < 2; ++ii ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
#pragma unroll
for( int jj = 0; jj < 4; ++jj ) {
if( !mask.is_valid(mi, ni, ii, jj) ) {
elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY;
}
}
}
}
}
}
// Apply the exp to all the elements.
template <bool max_in_base2=false, bool elt_in_base2=false>
inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)). This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
constexpr float kLog2e = M_LOG2E;
const float max_base2 = max_in_base2 ? max[mi] : max[mi] * kLog2e;
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
// elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
elt_[mi][ni] = apply_exp2_(elt_in_base2 ? elt_[mi][ni] : elt_[mi][ni] * kLog2e,
max_base2);
}
}
}
// Apply the exp to all the elements.
template <bool scale_max=true>
inline __device__ void scale_apply_exp(const float (&max)[MMAS_M * 2], const float scale_) {
const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E;
const float scale = scale_ * M_LOG2E;
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)). This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
const float max_scaled = max[mi] * max_scale;
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled);
}
}
}
// Apply the exp to all the elements.
template <bool max_in_base2=false>
inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
constexpr float kLog2e = M_LOG2E;
const float max_base2 = max_in_base2 ? max[ni] : max[ni] * kLog2e;
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
}
}
}
// inline __device__ void apply_exp_col(const float (&max)[MMAS_N]) {
// constexpr float kLog2e = M_LOG2E;
// #pragma unroll
// for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
// float max_base2 = max_in_base2 ? max[ni / 4] : max[ni / 4] * kLog2e;
// max_base2 = __shfl_sync(0xffffffff, max_base2, (ni % 4) * 8 + threadIdx.x % 8);
// #pragma unroll
// for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
// elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
// }
// }
// }
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
        // We encode the dropout pattern in the sign bit of the non-negative
        // softmax values to distinguish them from pre-existing zeros.
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni++ ) {
uint16_t tmp[8];
// fmha::uint4_to_ushort8(ph(), tmp);
uint4 tmp_32 = ph();
fmha::uint4_to_ushort8(tmp_32, tmp);
// if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("tidx = %d, ni = %d, ph Philox: %u, %u, %u, %u\n", threadIdx.x, ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
elt_[mi * 2 + ii][4 * ni + jj] =
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
}
}
}
}
}
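    // Note on the pattern above: each Philox call yields 128 random bits, unpacked into eight
    // uint16 draws, one per element held by this thread in the 16x16 tile of a given (mi, ni).
    // An element is kept when its draw is <= p_dropout_in_uint16_t (presumably the keep
    // probability scaled to 2^16 - 1); dropped elements become 0, or -val when the sign bit is
    // used to encode the mask.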
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t,
unsigned long long philox_subsequence) {
        // We encode the dropout pattern in the sign bit of the non-negative
        // softmax values to distinguish them from pre-existing zeros.
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
static_assert(MMAS_M == 1); // We're assuming 16x16 blocks.
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni++ ) {
uint16_t tmp[8];
// fmha::uint4_to_ushort8(ph(), tmp);
fmha::uint4_to_ushort8(ph(philox_subsequence + ni * Cta_tile::WARPS_N), tmp);
// uint4 tmp_32 = ph(philox_subsequence + ni * Cta_tile::WARPS_N);
// fmha::uint4_to_ushort8(tmp_32, tmp);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
elt_[mi * 2 + ii][4 * ni + jj] =
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
}
}
}
}
}
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout_16bits(Philox &ph0, Philox &ph1, uint16_t p_dropout_in_uint16_t) {
        // We encode the dropout pattern in the sign bit of the non-negative
        // softmax values to distinguish them from pre-existing zeros.
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
static_assert(MMAS_N % 2 == 0);
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni += 2 ) {
uint16_t tmp[8];
fmha::uint4_to_ushort8(ph0(), tmp);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
elt_[mi * 2 + ii][4 * ni + jj] =
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
}
}
fmha::uint4_to_ushort8(ph1(), tmp);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
elt_[mi * 2 + ii][4 * (ni + 1) + jj] =
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * (ni + 1) + jj]);
}
}
}
}
}
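// The three apply_dropout_16bits overloads above differ only in how the random bits
// are drawn: the first advances a single Philox stream call by call, the second
// indexes an explicit subsequence per ni tile, and the third interleaves two Philox
// streams so each pair of adjacent ni tiles consumes one draw from each stream.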
// Scale all the elements.
inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
// Precompute the inverse sum to normalize. Without -use_fast_math, this makes a huge difference.
float inv_sum[MMAS_M * 2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
}
// Update the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] *= inv_sum[mi];
}
}
}
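// Illustrative sketch of what scale() computes per row, assuming elt_ already holds
// the exponentiated, max-subtracted scores and sum holds the row-wise sums:
//   softmax(x)_i = exp(x_i - max) / sum_j exp(x_j - max)
// Rows whose sum is 0 or NaN are left unnormalized (inv_sum = 1.f) so that
// infinities are not propagated.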
// Subtract all elements by dp_sum
inline __device__ void subtract_dp_sum(const float (&dp_sum)[MMAS_M * 2]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] -= dp_sum[mi];
}
}
}
// The pointer to the mask.
const char *packed_mask_ptr_;
// Shared memory for the CTA-wide reduction.
float *smem_, *smem_write_, *smem_read_;
// The current thread index.
int tidx_;
// The elements.
float elt_[MMAS_M * 2][MMAS_N * 4];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile, typename Kernel_traits>
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
// The base class.
using Base = Softmax_base<Cta_tile, Kernel_traits>;
// The fragment.
using Fragment_a = fmha::Fragment_a<fmha::Row>;
static_assert(Fragment_a::NUM_REGS == 4);
static constexpr int WARPS_M = Cta_tile::WARPS_M;
static constexpr int WARPS_N = Cta_tile::WARPS_N;
// The MMAs.
static constexpr int MMAS_M = Base::MMAS_M;
static constexpr int MMAS_N = Base::MMAS_N;
// The accumulators.
using Accumulator = fmha::Fragment_accumulator;
using Accumulator_out = Fragment<uint16_t, 8>;
static_assert(Accumulator_out::NUM_REGS == 4);
static_assert(std::is_same<Accumulator::Data_type, float>::value);
using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
// Ctor.
template<typename Params>
inline __device__ Softmax(const Params &params, void *smem, int tidx)
: Base(params, smem, tidx)
, params_scale_bmm1_(params.scale_bmm1)
, smem_sum_(static_cast<float*>(smem), tidx)
, smem_max_(static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
}
// Pack the data to a fragment for the next GEMM.
template<typename elem_type=__half, int K, int M>
inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ki = 0; ki < K; ++ki ) {
// 1st row - 4 elements per row.
float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];
float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];
// 2nd row - 4 elements per row.
float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];
float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];
float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];
float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];
// Pack to 4 registers.
dst[ki][mi].reg(0) = fmha::float2_pack<elem_type>(tmp_00, tmp_01);
dst[ki][mi].reg(1) = fmha::float2_pack<elem_type>(tmp_10, tmp_11);
dst[ki][mi].reg(2) = fmha::float2_pack<elem_type>(tmp_02, tmp_03);
dst[ki][mi].reg(3) = fmha::float2_pack<elem_type>(tmp_12, tmp_13);
}
}
}
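// Register layout produced by pack() for each 16x16 tile, as read off the code above:
// reg(0) = row 0, cols {0,1}; reg(1) = row 1, cols {0,1}; reg(2) = row 0, cols {2,3};
// reg(3) = row 1, cols {2,3}, with each register holding two 16-bit values packed by
// float2_pack<elem_type>.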
// Scale FP32 fragments
inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
const float scalef = reinterpret_cast<const float &>(this->params_scale_bmm1_);
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;
// 2nd row - 4 elements per row.
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;
}
}
}
// Unpack FP32 fragments without scaling
inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
}
}
}
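// Accumulator element order assumed by unpack()/unpack_noscale(): elements 0,1,4,5 of
// each MMA accumulator map to the thread's first row of the 16x16 tile and elements
// 2,3,6,7 to its second row (8 rows below), matching the HMMA accumulator fragment
// layout on Turing/Ampere.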
template<bool zero_init=true, typename Operator>
__device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
#pragma unroll
for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
frag[mi] = zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]);
#pragma unroll
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
frag[mi] = op(frag[mi], this->elt_[mi][ni]);
}
}
}
template<bool zero_init=true, typename Operator>
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
thread_reduce_<zero_init>(frag, op);
quad_reduce(frag, frag, op);
smem_red.store(frag);
__syncthreads();
typename Smem_tile_red::read_t tmp[2 * MMAS_M];
smem_red.load(tmp);
quad_allreduce(frag, tmp, op);
}
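// reduce_() proceeds in three stages: a per-thread reduction over the 4*MMAS_N columns
// held in registers (thread_reduce_), a shuffle-based reduction within each quad of
// threads (quad_reduce), then a round trip through shared memory (store, __syncthreads,
// load) so the partial results of the WARPS_N warps can be combined, followed by a
// final quad_allreduce that broadcasts the result to every thread in the quad.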
template<bool zero_init=true>
__device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
MaxOp<float> max;
reduce_<zero_init>(frag, max, smem_max_);
}
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
SumOp<float> sum;
reduce_(frag, sum, smem_sum_);
}
template<bool zero_init=true>
__device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]){
SumOp<float> sum;
thread_reduce_<zero_init>(frag, sum);
quad_reduce(frag, frag, sum);
smem_sum_.store(frag);
}
template<int NROWS, typename Operator>
__device__ inline void reduce_after_sync_(float (&frag)[NROWS][MMAS_M],
const int (&rows)[NROWS],
Operator &op, Smem_tile_red & smem_red) {
#pragma unroll
for (int ii = 0; ii < NROWS; ii++) {
typename Smem_tile_red::read_t tmp[MMAS_M];
smem_red.load_row(tmp, rows[ii]);
quad_allreduce(frag[ii], tmp, op);
}
}
template<int NROWS>
__device__ inline void reduce_sum_after_sync_(float (&frag)[NROWS][MMAS_M],
const int (&rows)[NROWS]){
SumOp<float> sum;
reduce_after_sync_(frag, rows, sum, smem_sum_);
}
template<int NROWS>
__device__ inline void reduce_max_after_sync_(float (&frag)[NROWS][MMAS_M],
const int (&rows)[NROWS]){
MaxOp<float> max;
reduce_after_sync_(frag, rows, max, smem_max_);
}
const uint32_t params_scale_bmm1_;
Smem_tile_red smem_max_;
Smem_tile_red smem_sum_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Row {};
struct Col {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, bool = (M & (M-1)) == 0 >
struct Next_power_of_two {
};
template< int M >
struct Next_power_of_two< M, true > { enum { VALUE = M }; };
template<>
struct Next_power_of_two< 3, false> { enum { VALUE = 4 }; };
template<>
struct Next_power_of_two< 5, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 6, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 7, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 9, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 10, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 11, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 12, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 13, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 14, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 15, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 24, false> { enum { VALUE = 32 }; };
template<>
struct Next_power_of_two< 48, false> { enum { VALUE = 64 }; };
template<>
struct Next_power_of_two< 80, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two< 96, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two<112, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two<144, false> { enum { VALUE = 256 }; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, bool = (N & (N-1)) == 0 >
struct Prev_power_of_two {
};
template< int N >
struct Prev_power_of_two< N, true > { enum { VALUE = N }; };
template<>
struct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; };
template<>
struct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; };
template<>
struct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; };
template<>
struct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, int N >
struct Div_up {
enum { VALUE = (M + N-1) / N };
};
constexpr int DivUpConstexpr(int M, int N) { return (M + N - 1) / N; }
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B >
struct Max {
enum { VALUE = A >= B ? A : B };
};
constexpr int MaxConstexpr(int A, int B) { return A >= B ? A : B; }
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B, int C >
struct Max_3 {
enum { VALUE = Max<Max<A, B>::VALUE, C>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B >
struct Min {
enum { VALUE = A <= B ? A : B };
};
constexpr int MinConstexpr(int A, int B) { return A <= B ? A : B; }
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int SIZE_IN_BYTES >
struct Uint_from_size_in_bytes {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<1> {
using Type = uint8_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<2> {
using Type = uint16_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<4> {
using Type = uint32_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<8> {
using Type = uint2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<16> {
using Type = uint4;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int WARPS_M, int WARPS_N, int WARPS_K >
struct Warp_masks {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; };
template<>
struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; };
template<>
struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; };
template<>
struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; };
template<>
struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; };
template<>
struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; };
template<>
struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; };
template<>
struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; };
template<>
struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; };
template<>
struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; };
template<>
struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; };
template<>
struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; };
template<>
struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; };
template<>
struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename T >
inline __device__ __host__ T div_up(T m, T n) {
return (m + n-1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int clz(int x) {
for( int i = 31; i >= 0; --i ) {
if( (1 << i) & x ) {
return 31 - i;
}
}
return 32;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int find_log_2(int x, bool round_up = false) {
int a = 31 - clz(x);
if( round_up ) {
a += (x & (x-1)) ? 1 : 0;
}
return a;
}
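// Example: find_log_2(8) == 3, find_log_2(9) == 3, and find_log_2(9, true) == 4,
// since 9 is not a power of two and round_up adds one in that case.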
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) {
// uint32_t c;
// asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
// return c;
__half2 result = __hmul2(reinterpret_cast<const __half2 (&)>(a),
reinterpret_cast<const __half2 (&)>(b));
return reinterpret_cast<uint32_t(&)>(result);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
uint2 c;
c.x = hmul2(a.x, b.x);
c.y = hmul2(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
uint4 c;
c.x = hmul2(a.x, b.x);
c.y = hmul2(a.y, b.y);
c.z = hmul2(a.z, b.z);
c.w = hmul2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
uint4 c;
c.x = hmul2(a, b.x);
c.y = hmul2(a, b.y);
c.z = hmul2(a, b.z);
c.w = hmul2(a, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ uint32_t hrelu2(uint32_t x);
template<>
inline __device__ uint32_t hrelu2<__half>(uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
asm volatile( \
"{\n" \
"\t .reg .f16x2 sela;\n" \
"\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
"\t and.b32 %0, sela, %1;\n"
"}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
return res;
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t hrelu2<__nv_bfloat16>(uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
asm volatile( "max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
return res;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t habs2(uint32_t x) {
uint32_t res;
asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename T >
static inline __device__ T clamp(T x, T lb, T ub) {
return x < lb ? lb : (x > ub ? ub : x);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t clamp_to_zero(uint16_t x) {
uint16_t mask;
asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x));
return mask & x;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t float_to_half(float f) {
uint16_t h;
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
return h;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float2_to_half2(float a, float b) {
uint32_t c;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
#else
uint16_t lo = float_to_half(a);
uint16_t hi = float_to_half(b);
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi));
#endif
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ uint32_t float2_pack(float a, float b);
template <>
inline __device__ uint32_t float2_pack<__half>(float a, float b) {
__half2 result = __floats2half2_rn(a, b);
return reinterpret_cast<uint32_t(&)>(result);
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) {
__nv_bfloat162 result = __floats2bfloat162_rn(a, b);
return reinterpret_cast<uint32_t(&)>(result);
}
#endif
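// Example: float2_pack<__half>(1.f, 2.f) returns a uint32_t whose low 16 bits hold
// __float2half_rn(1.f) and whose high 16 bits hold __float2half_rn(2.f); the
// __nv_bfloat16 specialization does the same with bf16 encodings.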
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float_to_half2(float a) {
return float2_to_half2(a,a);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float2_to_half2(const float2 &f) {
return float2_to_half2(f.x, f.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) {
uint2 d;
d.x = float2_to_half2(x, y);
d.y = float2_to_half2(z, w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ uint2 float4_pack(float x, float y, float z, float w) {
uint2 d;
d.x = float2_pack<T>(x, y);
d.y = float2_pack<T>(z, w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
d = hrelu2<__half>(hfma2(a, b, c));
#endif
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t h0_h0(uint32_t x) {
uint32_t y;
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n"
: "=r"(y) : "r"(x));
return y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float h0_to_float(uint32_t h2) {
float f;
asm volatile("{\n" \
".reg .f16 lo, hi;\n" \
"mov.b32 {lo, hi}, %1;\n" \
"cvt.f32.f16 %0, lo;\n" \
"}\n" : "=f"(f) : "r"(h2));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t h1_h1(uint32_t x) {
uint32_t y;
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n"
: "=r"(y) : "r"(x));
return y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {
uint16_t d;
asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) {
return hadd2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hadd4(uint2 a, uint2 b) {
uint2 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hadd(uint2 a, uint2 b) {
return hadd4(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hadd8(uint4 a, uint4 b) {
uint4 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
c.z = hadd2(a.z, b.z);
c.w = hadd2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ float2 half2_unpack(uint32_t a);
template <>
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
return __half22float2(reinterpret_cast<__half2 (&)>(a));
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert two half2's or bf162's to float2, then take their dot product.
template <typename T>
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
float2 af = fmha::half2_unpack<T>(a);
float2 bf = fmha::half2_unpack<T>(b);
return af.x * bf.x + af.y * bf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert two vectors of 8 halfs or bf16s to float, then take their dot product.
template<typename T>
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
float sum;
sum = fmha::hfma2_to_float<T>(a.x, b.x);
sum += fmha::hfma2_to_float<T>(a.y, b.y);
sum += fmha::hfma2_to_float<T>(a.z, b.z);
sum += fmha::hfma2_to_float<T>(a.w, b.w);
return sum;
}
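// Usage sketch: given two uint4 registers a and b, each packing 8 fp16 (or bf16) values,
// hmulsum8<__half>(a, b) returns their dot product accumulated in fp32, e.g. for an
// 8-wide slice of a row-by-column product.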
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 fadd4(uint4 a, uint4 b) {
float4 c;
c.x = reinterpret_cast<const float&>(a.x) + reinterpret_cast<const float&>(b.x);
c.y = reinterpret_cast<const float&>(a.y) + reinterpret_cast<const float&>(b.y);
c.z = reinterpret_cast<const float&>(a.z) + reinterpret_cast<const float&>(b.z);
c.w = reinterpret_cast<const float&>(a.w) + reinterpret_cast<const float&>(b.w);
return reinterpret_cast<const uint4&>(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 fmul4(uint4 a, float b) {
float4 c;
c.x = reinterpret_cast<const float &>(a.x) * b;
c.y = reinterpret_cast<const float &>(a.y) * b;
c.z = reinterpret_cast<const float &>(a.z) * b;
c.w = reinterpret_cast<const float &>(a.w) * b;
return reinterpret_cast<const uint4 &>(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hadd(uint4 a, uint4 b) {
return hadd8(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float half_to_float(uint16_t h) {
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float2 half2_to_float2(uint32_t x) {
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x));
return make_float2(half_to_float(lo), half_to_float(hi));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) {
float2 tmp = half2_to_float2(h);
x = tmp.x;
y = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {
uint16_t d;
asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {
uint16_t d;
asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void uint4_to_ushort8(const uint4 a, uint16_t (&b)[8]) {
uint32_t *b_tmp = reinterpret_cast<uint32_t *>(&b[0]);
b_tmp[0] = a.x;
b_tmp[1] = a.y;
b_tmp[2] = a.z;
b_tmp[3] = a.w;
}
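// This reinterprets the 128 bits of a uint4 (e.g. one Philox draw) as 8 unsigned 16-bit
// values; apply_dropout_16bits compares each of them against the dropout threshold.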
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float sigmoid(float x) {
return 1.f / (1.f + expf(-x));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint16_t &dst) {
dst = uint16_t(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint32_t &dst) {
dst = 0u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint2 &dst) {
dst = make_uint2(0u, 0u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint4 &dst) {
dst = make_uint4(0u, 0u, 0u, 0u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// P R E D I C A T E P A C K I N G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE };
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// G E N E R I C P R E D I C A T E D L D G S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M, typename Functor >
inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) {
// The number of complete bytes (where we use all the predicates in a byte).
enum { COMPLETE = N / PREDS_PER_BYTE };
// Make sure we did allocate enough predicates.
static_assert(Div_up<COMPLETE, BYTES_PER_REG>::VALUE <= M, "");
// The remainder.
enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE };
// Make sure we got the math right and the remainder is between 0 and 3.
static_assert(REMAINDER >= 0 && REMAINDER <= 3, "");
// The mask to extract the predicates.
enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 };
// Clear the fetch registers.
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
fct.clear(ii);
}
// Run complete steps.
bool p[PREDS_PER_BYTE];
#pragma unroll
for( int ii = 0; ii < COMPLETE; ++ii ) {
// The predicate.
uint32_t reg = preds[ii / BYTES_PER_REG];
// Extract the predicates.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);
p[jj] = (reg & mask) != 0u;
}
// Issue the loads.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
fct.load(ii * PREDS_PER_BYTE + jj, p[jj]);
}
}
// Skip the rest of the code if we do not have a remainder.
if( REMAINDER > 0 ) {
// The mask to extract the predicates.
enum { REMAINDER_MASK = (1 << REMAINDER) - 1 };
// The predicate register.
uint32_t reg = preds[COMPLETE / BYTES_PER_REG];
// Extract the predicates.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);
p[jj] = (reg & mask) != 0u;
}
// Issue the loads.
#pragma unroll
for( int ii = 0; ii < REMAINDER; ++ii ) {
fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);
}
}
}
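// Worked example of the predicate packing above: with N = 6 predicates,
// COMPLETE = 6 / 4 = 1 full byte (predicates 0-3 read from bits 0-3 of preds[0]) and
// REMAINDER = 2, so predicates 4 and 5 are read from bits 8 and 9 of the same register.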
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, typename Functor >
inline __device__ void load_(Functor &fct, uint32_t preds) {
uint32_t tmp[1] = { preds };
load_<M>(fct, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint8_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint8_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint16_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint16_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint32_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint32_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint2 &dst, const void *ptr) {
dst = *reinterpret_cast<const uint2*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint4 &dst, const void *ptr) {
dst = *reinterpret_cast<const uint4*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N >
struct Ldg_functor {
// Ctor.
inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N])
: fetch_(fetch), ptrs_(ptrs) {
}
// Clear the element.
inline __device__ void clear(int ii) {
fmha::clear(fetch_[ii]);
}
// Trigger the loads.
inline __device__ void load(int ii, bool p) {
if( p ) {
ldg(fetch_[ii], ptrs_[ii]);
}
}
// The fetch registers.
Data_type (&fetch_)[N];
// The pointers.
const void* (&ptrs_)[N];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N, int M >
inline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
Ldg_functor<Data_type, N> fct(fetch, ptrs);
load_<N>(fct, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint8_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint16_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint32_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint2, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint4, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint16_t &dst, uint32_t ptr) {
asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint32_t &dst, uint32_t ptr) {
asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint2 &dst, uint32_t ptr) {
asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint4 &dst, uint32_t ptr) {
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x)
, "=r"(dst.y)
, "=r"(dst.z)
, "=r"(dst.w)
: "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S M
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
: "=r"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n"
: "=r"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}
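// The ldsm/ldsmt wrappers map to the ldmatrix PTX instruction: each call loads one, two,
// or four 8x8 tiles of 16-bit elements from shared memory and distributes them across the
// 32 threads of the warp; the .trans variants load the tiles transposed, which avoids an
// explicit transpose in shared memory.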
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint8_t val) {
*reinterpret_cast<uint8_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint16_t val) {
*reinterpret_cast<uint16_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint32_t val) {
*reinterpret_cast<uint32_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint2 val) {
*reinterpret_cast<uint2*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint4 val) {
*reinterpret_cast<uint4*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint16_t val) {
asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint32_t val) {
asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint2 val) {
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n"
:
: "r"(ptr)
, "r"(val.x)
, "r"(val.y));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint4 val) {
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
:
: "r"(ptr)
, "r"(val.x)
, "r"(val.y)
, "r"(val.z)
, "r"(val.w));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N >
inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
sts(ptrs[ii], data[ii]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {
sts_<uint16_t, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {
sts_<uint32_t, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {
sts_<uint2, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
sts_<uint4, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(__half2 (&dst)[M], __half2 (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_reduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(__half2 (&dst)[M], float2 (&src)[M], Operator &op) {
__half2 tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(reinterpret_cast<const __half2 &>(src[mi].x),
reinterpret_cast<const __half2 &>(src[mi].y));
}
quad_reduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = Allreduce<4>::run(dst[mi], op);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(__half2 (&dst)[M], __half2 (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = Allreduce<4>::run(dst[mi], op);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_allreduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(__half2 (&dst)[M], float2 (&src)[M], Operator &op) {
__half2 tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(reinterpret_cast<const __half2 &>(src[mi].x),
reinterpret_cast<const __half2 &>(src[mi].y));
}
quad_allreduce(dst, tmp, op);
}
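// quad_reduce leaves the combined value in the first lane of each group of four threads
// (shuffles down by 2 and then 1), while quad_allreduce uses the butterfly pattern of
// Allreduce<4> so that all four lanes end up with the same reduced value.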
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/* Copyright (c) 2022, Tri Dao.
*/
#include "fmha.h"
#include "fmha_block_dgrad_kernel_1xN_loop.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
fmha::compute_block_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}
template<typename Kernel_traits>
void run_fmha_block_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
static_assert(smem_size_dp_sum == 16 * 4 * 2);
constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;
bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping"
bool is_causal = params.is_causal;
auto kernel = is_dropout
? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
: (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
if (params.seqlen_k == blocksize_c) {
kernel = is_dropout
? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
: (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
} else if (params.seqlen_k == blocksize_c * 2) {
kernel = is_dropout
? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
: (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
}
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
dim3 grid(params.b, params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
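// Note on the cudaFuncSetAttribute call above: a kernel may only use more than 48 KB of
// dynamic shared memory per block if it opts in explicitly, so the launcher raises
// cudaFuncAttributeMaxDynamicSharedMemorySize whenever the combined Q/V/dQ/S/dp_sum
// tiles exceed that limit; the same byte count is then passed as the third launch
// parameter.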
void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
if (params.d == 16) {
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if (params.d == 32) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if (params.d == 64) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
}
\ No newline at end of file