Commit 1fcbe6f0 authored by Tri Dao

First release

/* Copyright (c) 2022, Tri Dao.
*/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_loop.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}
template<typename Kernel_traits>
void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
static_assert(smem_size_dp_sum == 16 * 4 * 2);
constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;
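// Illustrative sizing note (a sketch, not used by the code): from the static_asserts above,
// each S / dP tile takes 16 * Cta_tile_p::N * 2 bytes, i.e. 8 KB when N = 256, and the total
// counts two of them (S and dP), two Q buffers, and one or two V buffers depending on
// Kernel_traits::V_IN_REGS. Totals above 48 KB require the cudaFuncSetAttribute opt-in below.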
bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping"
bool is_causal = params.is_causal;
auto kernel = is_dropout
? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
: (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
constexpr int N = Kernel_traits::Cta_tile_p::N;
if (params.s == N) {
kernel = is_dropout
? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
: (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
} else if (params.s == N * 2) {
kernel = is_dropout
? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
: (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
}
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
dim3 grid(params.h, params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
if (params.d == 16) {
if( params.s == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if( params.s == 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else {
// TD [2022-05-15] 512 gives wrong results rn
// using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
} else if (params.d == 32) {
if( params.s == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if( params.s >= 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
} else if (params.d == 64) {
if( params.s == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if( params.s >= 256 ) {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
// Don't share smem for K & V, and don't keep V in registers
// This speeds things up by 2-3% by avoiding register spills, but it
// uses more shared memory, which is fine on A100 but not other GPUs.
// For other GPUs, we should either use N=128 as the base, or keep V in registers.
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
} else if (params.d == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
}
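// Illustrative call site (a sketch, not part of this file; the field names are the ones used above):
// the caller fills in the problem shape and the dispatcher picks Kernel_traits from params.d
// and params.s. For example, with d = 64 and s = 512 the FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>
// path is taken, and since s == 2 * Cta_tile_p::N the loop_steps == 2 specialization is used.
//
//   Fused_multihead_attention_fprop_params params;  // filled in by the caller
//   params.b = batch_size; params.h = num_heads; params.s = 512; params.d = 64;
//   run_fmha_dgrad_fp16_sm80(params, stream);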
/* Copyright (c) 2022, Tri Dao.
*/
#pragma once
#include "fmha_fprop_kernel_1xN.h"
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Smem_dp_sum, int M>
inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M],
Smem_dp_sum smem, const int buffer_idx) {
#pragma unroll
for (int mi = 0; mi < M; ++mi) {
sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi]));
}
static_assert(M == 1);
smem.store(sum[0], buffer_idx);
// smem.store(sum, buffer_idx);
}
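// In math terms, a sketch of what dot_do_o computes (assuming hmulsum8 multiplies the 8 fp16
// values held in each uint4 pair element-wise and sums them to a float):
//   dp_sum[row] = sum_d dO[row, d] * O[row, d]
// i.e. the per-row dot product of dO and O, reduced through Smem_dp_sum and stored into the
// buffer selected by buffer_idx.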
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng &ph,
const int loop_step_idx) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dq = typename Kernel_traits::Cta_tile_o;
// The description of the CTA tile for the 3rd batched GEMM.
using Cta_tile_dkv =
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128);
static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128);
static_assert(Cta_tile_dkv::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dq = fmha::Hmma_tile<Cta_tile_dq>;
// The MMA tile for the 3rd GEMM.
using Mma_tile_dkv = fmha::Hmma_tile<Cta_tile_dkv>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to reload Q transposed.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K^T. Treat K^T as V
using Smem_tile_kt = typename Kernel_traits::Smem_tile_v;
// Treating V as K. We need to use Kernel_traits::Smem_tile_k, otherwise the loads will be wrong.
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load dO.
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
// The shared memory tile to load dO.
// Treating dO as Q.
using Smem_tile_do = typename Kernel_traits::Smem_tile_q;
// The shared memory tile to reload dO transposed.
using Smem_tile_dot = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The global memory tile to load O. Loading O here is similar to loading dO.
using Gmem_tile_o = Gmem_tile_do;
// The global memory tile to store dQ.
// using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
using Gmem_tile_dq = fmha::Gmem_tile_dq<Cta_tile_dq>;
using Gmem_tile_dq_tmp = fmha::Gmem_tile_o<Cta_tile_dq, 4>;
// The shared memory tile to swizzle dQ.
using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store dV.
using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle dV.
using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
// The global memory tile to store dK.
using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle dK.
using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum;
// using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Gemm1 = Gemm_Q_K<Kernel_traits, /*K_in_regs=*/false>;
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
// Shared memory.
extern __shared__ char smem_[];
// Shared memory layout if we keep V in registers:
// dO | Q | K / V | dQ | S | dP | dP_sum
// dV | dK
// Shared memory layout if we keep V in shared memory:
// dO | Q | K | V | dQ | S | dP | dP_sum
// dV | dK
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
// if( binfo.stop_early() ) return;
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the global memory tile loader for dQ.
Gmem_tile_dq gmem_dq(params, 0, binfo, tidx);
Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!!
Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for dO.
Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
// Allocate the shared memory tile loader for dO.
Smem_tile_do smem_do(&smem_[0], tidx);
Smem_tile_dot smem_dot(&smem_[0], tidx);
// Allocate the shared memory tile loader for Q^T.
// TODO: assert that this points to the same memory as gemm_q_k.smem_q
Smem_tile_qt smem_qt(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
Smem_tile_st smem_s(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], tidx);
Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx);
// Allocate the shared memory tile loader for dQ. We use the same tile as O so be careful!!!
Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
const int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
// constexpr int steps = Cta_tile_p::N / Cta_tile_p::M;
const int steps = params.s / Cta_tile_p::M - begin;
// Wind gmem tiles to the correct position.
gmem_q.move(begin);
gmem_do.move(begin);
gmem_o.move(begin);
gmem_dq.move(begin);
gmem_dq_tmp.move(begin);
// TODO: need to move gmem_s if we want the intermediate result for debugging
gmem_softmax_lse.move(begin);
gmem_softmax_d.move(begin);
if (!Is_first) {
gmem_k.move(loop_step_idx);
gmem_v.move(loop_step_idx);
}
// Trigger the loads for K.
gmem_k.load();
// Trigger the loads for Q.
gmem_q.load();
// Trigger the loads for V.
gmem_v.load();
// Trigger the loads for dO.
gmem_do.load();
// Trigger the loads for O.
if (Is_first) { gmem_o.load(); }
float p_lse[Mma_tile_p::MMAS_M * 2];
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
gmem_softmax_lse.move();
float dp_sum[Mma_tile_p::MMAS_M * 2];
if (!Is_first) {
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
gmem_softmax_d.move();
}
float dp_sum_regs[Gmem_tile_do::LDGS];
Smem_dp_sum smem_dp_sum(reinterpret_cast<float *>(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE * 2]), tidx);
if (!Is_first) { __syncthreads(); }
// Commit the data for Q, dO, and V to shared memory.
gmem_q.commit(gemm_q_k.smem_q);
gmem_do.commit(smem_do);
if (Is_first) {
dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0);
const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW;
if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row);
}
gmem_softmax_d.move();
}
// Instead of scaling dP by rp_dropout, we scale V instead
if (Is_dropout) {
const uint32_t scale_dropout = params.scale_dropout;
#pragma unroll
for(int it=0; it < Gmem_tile_v::LDGS; it++){
gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
}
}
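// Why scaling V works (a sketch of the algebra, assuming params.scale_dropout packs
// rp_dropout = 1 / (1 - p_dropout) as fp16): dP = dO @ V^T is linear in V, so
//   rp_dropout * (dO @ V^T) == dO @ (rp_dropout * V)^T,
// which folds the dropout rescaling into the V load instead of an extra pass over dP.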
gmem_v.commit(smem_v);
// const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
// #pragma unroll
// for(int it=0; it < Gmem_tile_k::LDGS; it++){
// gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
// }
// Commit the data for K to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_k.commit(gemm_q_k.smem_k);
}
__syncthreads();
// Load the fragments for Q.
gemm_q_k.load_q();
// Load the fragments for V. If V_IN_REGS, we keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Kernel_traits::V_IN_REGS ? Mma_tile_p::MMAS_K : 2][Mma_tile_p::MMAS_N];
if (Kernel_traits::V_IN_REGS) {
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_k.commit(gemm_q_k.smem_k);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for K.
gemm_q_k.load_k();
// Load the fragments for K^T.
// typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
// smem_kt.load(frag_kt[0], 0);
// typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N];
// #pragma unroll
// for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) {
// smem_kt.load(frag_kt[ki], ki);
// }
// Create the object to do the softmax.
// We won't be using the shared memory for this softmax at all
Softmax softmax(params, smem_, tidx);
// Declare the accumulators for the 3rd gemm.
fmha::Fragment_accumulator acc_dv[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dv);
fmha::Fragment_accumulator acc_dk[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);
// Load over the entire sequence length.
for( int l = 0; l < steps; l++ ) {
const int loop = (begin + l) * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
// Load the fragments for V.
// typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_N];
if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[0], 0); }
// Load the fragments for dO.
typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M];
smem_do.load(frag_do[0], 0);
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
gemm_q_k(acc_p);
// Load the mask for that iteration.
mask.load(begin + l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack_noscale(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
// Scale by log-sum-exp of the softmax
// softmax.apply_exp(p_lse);
softmax.template scale_apply_exp</*scale_max=*/false>(p_lse, params.scale_bmm1f);
if (Is_dropout) {
// softmax.apply_dropout(ph, params.p_dropout_in_uint);
// softmax.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint);
softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t);
}
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
softmax.pack(frag_p);
// Store s * dmask to smem for transpose
smem_s.store(frag_p);
// Trigger the load for the next Q values.
if( l < steps - 1) {
gemm_q_k.smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load();
}
// if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
// __syncthreads();
// }
// TD [2022-04-24]: if Is_first, then it's faster to set acc_dp to zero then subtract by
// dp_sum later. If !Is_first, then it's faster to set acc_dp to -dp_sum and don't subtract
// later. This is because loading dp_sum earlier uses more registers.
fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
if (Is_first) {
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dp);
} else {
#pragma unroll
for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
#pragma unroll
for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
#pragma unroll
for (int ii = 0; ii < 8; ++ii) {
acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
}
}
}
}
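// Either way the backward ends up using dP - dp_sum, where dp_sum = rowsum(dO * O): with
// !Is_first the -dp_sum term is folded into the accumulator init above, with Is_first it is
// subtracted later via softmax.subtract_dp_sum.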
// Do this part of dP^T = (dO * V^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of dO values.
smem_do.load(frag_do[ki & 1], ki);
if (!Kernel_traits::V_IN_REGS) {
smem_v.load(frag_v[ki & 1], ki);
fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
} else {
fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
// printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y);
// tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[(ki - 1) & 1]));
// printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y);
// }
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
if (!Kernel_traits::V_IN_REGS) {
fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
} else {
fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
}
}
// Load the fragments for K^T.
typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
smem_kt.load(frag_kt[0], 0);
if (Is_first) {
const int quad = (tidx % Cta_tile_p::THREADS_PER_WARP) / 4;
const int row[2] = {quad, quad + 8};
smem_dp_sum.load(dp_sum, row, l % 2);
}
// Trigger the load for the next dO values.
if( l < steps - 1) {
smem_do.move_to_next_write_buffer();
gmem_do.move();
gmem_do.load();
if (Is_first) {
gmem_o.move();
gmem_o.load();
}
}
softmax.unpack_noscale(acc_dp);
// // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
// // will be zero.
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
if (Is_first) { softmax.subtract_dp_sum(dp_sum); }
Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
softmax.pack(frag_dp);
if (!Is_dropout) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
frag_p[mi][ni].hmul(frag_dp[mi][ni]);
}
}
} else {
__half2 dp_sum_half[Mma_tile_p::MMAS_M * 2];
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]);
}
const __half zero_h = __half(0.f);
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
#pragma unroll
for (int ii = 0; ii < 4; ++ii) {
const __half2 p = frag_p[mi][ni].template elt_as<__half2>(ii);
const __half2 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__half2>(ii));
// If this element is dropped, then frag_p stores -p instead of p.
// So pd holds -p * dp_sum in that case.
const __half2 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]);
const __half low = __low2half(p) >= zero_h ? __low2half(pdp) : __low2half(pd);
const __half high = __high2half(p) >= zero_h ? __high2half(pdp) : __high2half(pd);
frag_p[mi][ni].template elt_as<__half2>(ii) = __halves2half2(low, high);
}
}
}
}
// Store dp to smem for transpose
smem_dp.store(frag_p);
// gmem_s.store(frag_p, mask);
// gmem_s.move();
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dq::WARPS_K>::apply(acc_dq);
// Do this part of dQ = dS * K.
#pragma unroll
for( int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of K^T values.
smem_kt.load(frag_kt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dq::MMAS_K;
fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
static_assert(Gmem_tile_dq::LOOPS == 1);
// Swizzle the elements and do the final reduction.
smem_dq.store(acc_dq, 0);
typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N];
static_assert(Smem_tile_dot::Fragment::NUM_REGS == 4);
static_assert(Mma_tile_dkv::MMAS_K == 1);
smem_dot.load(frag_dot[0], 0);
// Threads in a warp are communicating via shared memory (smem_s and smem_dp)
__syncwarp();
typename Smem_tile_st::Fragment frag_s[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
smem_s.load(frag_s);
if (Is_dropout) {
#pragma unroll
for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
frag_s[ki][mi].hrelu_();
}
}
}
#pragma unroll
for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of dO^T values.
smem_dot.load(frag_dot[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dkv::MMAS_K;
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
}
// __syncthreads();
// Commit the values for Q into shared memory.
if(l < steps - 1) {
gmem_q.commit(gemm_q_k.smem_q);
}
uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP];
if (!Is_first) { gmem_dq_tmp.load(dq_out, 0); }
// __syncthreads();
// Commit the values for dO into shared memory.
if(l < steps - 1) {
gmem_do.commit(smem_do);
if (Is_first) {
// dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum);
// smem_dp_sum.move_to_next_write_buffer();
dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2);
const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW;
if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row_1);
}
gmem_softmax_d.move();
}
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
gmem_softmax_lse.move();
if (!Is_first) {
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
gmem_softmax_d.move();
}
}
typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
smem_dp.load(frag_dpt);
gemm_q_k.reload_k();
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dkv::MMAS_N];
static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
static_assert(Mma_tile_dkv::MMAS_K == 1);
smem_qt.load(frag_qt[0], 0);
#pragma unroll
for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q^T values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dkv::MMAS_K;
fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Make sure dQ is in shared memory.
__syncthreads();
// Load from shared memory.
smem_dq.template load</*zero_init=*/Is_first>(dq_out);
const bool is_final_write =
Is_last
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
|| ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
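// Partial dQ values are accumulated across K/V blocks through gmem_dq_tmp (loaded above,
// stored below); only the step that provides the last contribution for these rows -- the
// last loop step, a step past the actual sequence length, or, with the causal mask, a query
// block whose rows are all masked out in every later K/V block -- applies the softmax scale
// and writes the final dQ.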
if (is_final_write) {
// if (Is_dropout) {
// dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
// }
dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f);
// Output the values.
gmem_dq.store(dq_out, 0);
// Move to the next part of the output.
gmem_dq.move();
} else {
// Output the values.
gmem_dq_tmp.store(dq_out, 0);
}
// Move to the next part of the output.
if (!(Is_first && Is_last)) { gmem_dq_tmp.move(); }
// // Make sure the data is in shared memory.
// __syncthreads();
// Move to the next shared memory read buffers for Q and dO and reload Q.
if(l < steps - 1) {
gemm_q_k.smem_q.move_to_next_read_buffer();
gemm_q_k.reload_q();
smem_qt.move_to_next_read_buffer();
// smem_qt.load(frag_qt[0], 0);
smem_do.move_to_next_read_buffer();
smem_dot.move_to_next_read_buffer();
// smem_dot.load(frag_dot[0], 0);
}
} // Outer loop over the sequence length.
if (Is_dropout) {
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
acc_dv[mi][ni].mul_(params.rp_dropout);
}
}
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
// }
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
// acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f);
acc_dk[mi][ni].mul_(params.scale_bmm1f);
}
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
// }
__syncthreads();
// TODO [TD - 2022-05-04]: Are there cases where the shared mem for dV and dK are larger than
// the total amount of shared mem?
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[0], tidx);
smem_dv.store(acc_dv);
// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
smem_dk.store(acc_dk);
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out);
Qkv_params dv_params;
dv_params.qkv_ptr = params.dqkv_ptr;
dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dv_params.h = params.h;
Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
if (!Is_first) {
gmem_dv.move(loop_step_idx);
}
gmem_dv.store(dv_out);
uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
// for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
// dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
// }
Qkv_params dk_params;
dk_params.qkv_ptr = params.dqkv_ptr;
dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dk_params.h = params.h;
Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
if (!Is_first) {
gmem_dk.move(loop_step_idx);
}
gmem_dk.store(dk_out);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// loop_steps = -1 means the number of steps will be params.s / Kernel_traits::Cta_tile_p::N.
// This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params>
inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const int tidx_global = (bidb * params.h + bidh) * blockDim.x + tidx;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
if (loop_steps == 1) {
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
} else if (loop_steps == 2) {
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1);
} else {
if (params.s == N_per_loop) {
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
} else {
const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx);
}
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, max_loop_steps - 1);
}
}
}
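// Worked example of the looping above (illustrative only): with Cta_tile_p::N = 256 and
// params.s = 768, max_loop_steps = 3 and the K/V blocks are visited with
// (Is_first, Is_last) = (true, false), (false, false), (false, true), so dQ is accumulated
// through the o_tmp buffer on the first two steps and finalized on the last one.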
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/* Copyright (c) 2022, Tri Dao.
*/
#pragma once
#include "fmha_fprop_kernel_1xN.h"
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMAS_M>
inline __device__ void dot_fragments(float (&sum)[MMAS_M * 2],
const fmha::Fragment_a<fmha::Row> (&x)[MMAS_M],
const fmha::Fragment_a<fmha::Row> (&y)[MMAS_M]) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
sum[mi * 2 + 0] += hfma2_to_float(x[mi].template elt_as<__half2>(0),
y[mi].template elt_as<__half2>(0));
sum[mi * 2 + 0] += hfma2_to_float(x[mi].template elt_as<__half2>(2),
y[mi].template elt_as<__half2>(2));
sum[mi * 2 + 1] += hfma2_to_float(x[mi].template elt_as<__half2>(1),
y[mi].template elt_as<__half2>(1));
sum[mi * 2 + 1] += hfma2_to_float(x[mi].template elt_as<__half2>(3),
y[mi].template elt_as<__half2>(3));
// hfma2_to_float(sum[mi * 2 + 0], x[mi].template elt_as<__half2>(0), y[mi].template elt_as<__half2>(0));
// hfma2_to_float(sum[mi * 2 + 0], x[mi].template elt_as<__half2>(2), y[mi].template elt_as<__half2>(2));
// hfma2_to_float(sum[mi * 2 + 1], x[mi].template elt_as<__half2>(1), y[mi].template elt_as<__half2>(1));
// hfma2_to_float(sum[mi * 2 + 1], x[mi].template elt_as<__half2>(3), y[mi].template elt_as<__half2>(3));
}
}
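// A sketch of what dot_fragments accumulates, assuming the usual m16n8k16 A-operand layout
// where registers 0 and 2 of a fragment hold one row owned by the thread and registers 1 and 3
// hold the row 8 below it:
//   sum[2 * mi + 0] += dot of this thread's slice of row 0 of x and y
//   sum[2 * mi + 1] += dot of this thread's slice of row 1 of x and y
// The per-thread partial sums are later reduced across the quad with quad_allreduce.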
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dp_dq_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dq = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dq = fmha::Hmma_tile<Cta_tile_dq>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K^T. Treat K^T as V
using Smem_tile_kt = typename Kernel_traits::Smem_tile_v;
// Treating V as K. We need to use Kernel_traits::Smem_tile_k, otherwise the loads will be wrong.
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load dO.
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
// The shared memory tile to load dO.
// Treating dO as Q.
using Smem_tile_do = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load O. Loading O here is similar to loading dO.
using Gmem_tile_o = Gmem_tile_do;
// The shared memory tile to load O.
using Smem_tile_o = Smem_tile_do;
// The global memory tile to store dQ.
// using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq;
using Gmem_tile_dq = fmha::Gmem_tile_dq<Cta_tile_dq>;
// The shared memory tile to swizzle dQ.
using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
// using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Gemm1 = Gemm_Q_K<Kernel_traits, /*K_in_regs=*/false>;
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() ) return;
Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the global memory tile loader for dQ.
Gmem_tile_dq gmem_dq(params, 0, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
fmha::Mask<Cta_tile_p> mask(binfo, tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v;
char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!!
Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for dO.
Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
// Allocate the shared memory tile loader for dO.
Smem_tile_do smem_do(&smem_[0], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx);
// Allocate the shared memory tile loader for O.
Smem_tile_o smem_o(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
// Allocate the shared memory tile loader for dQ. We use the same tile as O so be careful!!!
Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
// Trigger the loads for K.
gmem_k.load();
// Trigger the loads for Q.
gmem_q.load();
// Trigger the loads for V.
gmem_v.load();
// Trigger the loads for dO.
gmem_do.load();
// Trigger the loads for O.
gmem_o.load();
const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
#pragma unroll
for(int it=0; it < Gmem_tile_k::LDGS; it++){
gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
}
// Commit the data for Q, dO, and V to shared memory.
gmem_q.commit(gemm_q_k.smem_q);
gmem_do.commit(smem_do);
gmem_o.commit(smem_o);
gmem_v.commit(smem_v);
// Commit the data for K to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_k.commit(gemm_q_k.smem_k);
}
__syncthreads();
// Load the fragments for Q.
gemm_q_k.load_q();
// Load the fragments for dO.
typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M];
smem_do.load(frag_do[0], 0);
// Load the fragments for O.
typename Smem_tile_o::Fragment frag_o[2][Mma_tile_p::MMAS_M];
smem_o.load(frag_o[0], 0);
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_k.commit(gemm_q_k.smem_k);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for K.
gemm_q_k.load_k();
// Load the fragments for K^T.
typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
smem_kt.load(frag_kt[0], 0);
// typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N];
// #pragma unroll
// for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) {
// smem_kt.load(frag_kt[ki], ki);
// }
// Create the object to do the softmax.
// We won't be using the shared memory for this softmax at all
// Softmax softmax(params, &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], bidb, tidx);
Softmax softmax(params, smem_, tidx);
// Softmax softmax_dp(params, &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], bidb, tidx);
Gmem_softmax_sum gmem_softmax_sum(params.softmax_lse_ptr, params, tidx);
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
constexpr int STEPS = Cta_tile_p::N / Cta_tile_p::M;
// Load over the entire sequence length.
for( int l = 0; l < STEPS; l++ ) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
float p_lse[Mma_tile_p::MMAS_M * 2];
gmem_softmax_sum.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
gmem_softmax_sum.move();
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
gemm_q_k(acc_p);
// Trigger the load for the next Q values.
if( l < STEPS - 1) {
gemm_q_k.smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load();
}
// Load the mask for that iteration.
mask.load(l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack_noscale(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
// Scale by log-sum-exp of the softmax
softmax.template apply_exp</*max_in_base2=*/true>(p_lse);
// softmax.unpack_noscale_half_and_apply_mask(acc_p, mask);
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
softmax.pack(frag_p);
// if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
// __syncthreads();
// }
float dp_sum_new[Mma_tile_p::MMAS_M * 2] = {0};
dot_fragments(dp_sum_new, frag_do[0], frag_o[0]);
fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dp);
// Do this part of dP^T = (dO * V^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of dO values.
smem_do.load(frag_do[ki & 1], ki);
smem_o.load(frag_o[ki & 1], ki);
dot_fragments(dp_sum_new, frag_do[ki & 1], frag_o[ki & 1]);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("dp_sum_new=%.6f, %.6f\n", dp_sum_new[0], dp_sum_new[1]);
// }
// smem_v.load(frag_v[ki & 1], ki);
// Do the math for the values already in registers.
// fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
// if ((threadIdx.x == 1) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
// printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y);
// tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki - 1]));
// printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y);
// }
fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
// fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
}
// if ((threadIdx.x == 1) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("acc_dp=%.6f, %.6f\n", acc_dp[0][0].elt(0), acc_dp[0][0].elt(1));
// }
// Trigger the load for the next dO values.
if( l < STEPS - 1) {
smem_do.move_to_next_write_buffer();
gmem_do.move();
gmem_do.load();
smem_o.move_to_next_write_buffer();
gmem_o.move();
gmem_o.load();
}
// softmax_dp.unpack_noscale(acc_dp);
// // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
// // will be zero.
// #pragma unroll
// for( int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++ ) {
// #pragma unroll
// for( int ni = 0; ni < Mma_tile_p::MMAS_N * 4; ni++ ) {
// softmax_dp.elt_[mi][ni] *= softmax.elt_[mi][ni];
// }
// }
// float dp_sum[Mma_tile_p::MMAS_M * 2];
// softmax_dp.reduce_sum(dp_sum);
// gmem_softmax_d.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
// gmem_softmax_d.move();
// #pragma unroll
// for( int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++ ) {
// #pragma unroll
// for( int ni = 0; ni < Mma_tile_p::MMAS_N * 4; ni++ ) {
// softmax_dp.elt_[mi][ni] -= dp_sum[mi] * softmax.elt_[mi][ni];
// }
// }
fmha::SumOp<float> sum_op;
fmha::quad_allreduce(dp_sum_new, dp_sum_new, sum_op);
// softmax_dp.unpack_noscale(acc_dp);
softmax.unpack_noscale(acc_dp);
// #pragma unroll
// for( int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++ ) {
// #pragma unroll
// for( int ni = 0; ni < Mma_tile_p::MMAS_N * 4; ni++ ) {
// // softmax_dp.elt_[mi][ni] -= dp_sum_new[mi];
// softmax.elt_[mi][ni] -= dp_sum_new[mi];
// }
// }
softmax.subtract_dp_sum(dp_sum_new);
Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
// softmax_dp.pack(frag_dp);
softmax.pack(frag_dp);
gmem_softmax_d.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum_new));
gmem_softmax_d.move();
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
frag_p[mi][ni].hmul(frag_dp[mi][ni]);
}
}
// softmax_dp.pack(frag_p);
// gmem_s.store(frag_p, mask);
// gmem_s.move();
// __syncthreads();
// Commit the values for Q, dO, and O into shared memory.
if(l < STEPS - 1) {
gmem_q.commit(gemm_q_k.smem_q);
gmem_do.commit(smem_do);
gmem_o.commit(smem_o);
}
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dq::WARPS_K>::apply(acc_dq);
// Do this part of dQ = dS * K.
#pragma unroll
for( int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of K^T values.
smem_kt.load(frag_kt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dq::MMAS_K;
fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_dq::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_dq.store(acc_dq, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_dq::STGS_PER_LOOP];
smem_dq.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_dq::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_dq.store(out, ii);
}
// Move to the next part of the output.
gmem_dq.move();
gemm_q_k.reload_k();
smem_kt.load(frag_kt[0], 0);
// // Make sure the data is in shared memory.
// __syncthreads();
// Move to the next shared memory read buffers for Q, dO, and O and reload them.
if(l < STEPS - 1) {
gemm_q_k.smem_q.move_to_next_read_buffer();
gemm_q_k.reload_q();
smem_do.move_to_next_read_buffer();
smem_do.load(frag_do[0], 0);
smem_o.move_to_next_read_buffer();
smem_o.load(frag_o[0], 0);
}
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dv_dk_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dkv = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dkv>;
// The global memory tile to load Q. Treating Q as K.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_k;
// The global memory tile to load K. Treating K as Q.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q^T. Treat Q^T as V
using Smem_tile_qt = typename Kernel_traits::Smem_tile_v;
// Treating V as dO.
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_q;
// Treating dO as V in dQ kernel, which is the same as K in the forward kernel.
// The global memory tile to load dO.
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_dot;
// The shared memory tile to load dO.
using Smem_tile_do = typename Kernel_traits::Smem_tile_k;
// The shared memory tile to load dO^T.
using Smem_tile_dot = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store dK and dV.
// using Gmem_tile_dkv = typename Kernel_traits::Gmem_tile_dkv;
using Gmem_tile_dkv = fmha::Gmem_tile_dq<Cta_tile_dkv>;
// The shared memory tile to swizzle dK and dV.
using Smem_tile_dkv = typename Kernel_traits::Smem_tile_o;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
// using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Gemm1 = Gemm_Q_K<Kernel_traits, /*K_in_regs=*/false>;
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() ) return;
Gemm1 gemm_q_k(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the global memory tile loader for dK.
Gmem_tile_dkv gmem_dk(params, 1, binfo, tidx);
// Allocate the global memory tile loader for dV.
Gmem_tile_dkv gmem_dv(params, 2, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
fmha::Mask<Cta_tile_p> mask(binfo, tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// Allocate the shared memory tile loader for V (it reuses the Q tile).
Smem_tile_v smem_v(&smem_[0], tidx);
// Allocate the shared memory tile loader for Q^T. We use the same as Q so be careful!!!
Smem_tile_qt smem_qt(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for dO.
Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx);
// The base pointer of smem_do;
char *smem_do_ = &smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
// Allocate the shared memory tile loader for dO. We use the same as K so be careful!!!
Smem_tile_do smem_do(smem_do_, tidx);
Smem_tile_dot smem_dot(smem_do_, tidx);
// Allocate the shared memory tile loader for dK and dV. We use the same as K so be careful!!!
Smem_tile_dkv smem_dkv(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
// Trigger the loads for Q.
gmem_q.load();
// Trigger the loads for K.
gmem_k.load();
// Trigger the loads for dO.
gmem_do.load();
// Trigger the loads for V.
gmem_v.load();
const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
#pragma unroll
for(int it=0; it < Gmem_tile_q::LDGS; it++){
gmem_q.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_q.fetch_[it]);
}
// Commit the data for K, dO, and V to shared memory.
gmem_k.commit(gemm_q_k.smem_q);
gmem_v.commit(smem_v);
gmem_do.commit(smem_do);
// Commit the data for Q to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_q.commit(gemm_q_k.smem_k);
}
__syncthreads();
// Load the fragments for K.
gemm_q_k.load_q();
// Load the fragments for V.
typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_M];
smem_v.load(frag_v[0], 0);
// Load the fragments for dO. We keep the data in registers during the entire kernel.
typename Smem_tile_do::Fragment frag_do[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_dk::MMAS_K; ++ki ) {
smem_do.load(frag_do[ki], ki);
}
using Smem_tile_mma_t = fmha::Smem_tile_transpose<Cta_tile_p>;
// Smem_tile_mma_t smem_mmat(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
Smem_tile_mma_t smem_mmat(&smem_[Smem_tile_v::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dkv::BYTES_PER_TILE], tidx);
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_q.commit(gemm_q_k.smem_k);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for Q.
gemm_q_k.load_k();
// Load the fragments for K^T.
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];
smem_qt.load(frag_qt[0], 0);
// typename Smem_tile_qt::Fragment frag_qt[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_N];
// #pragma unroll
// for( int ki = 0; ki < Mma_tile_dk::MMAS_K; ++ki ) {
// smem_qt.load(frag_qt[ki], ki);
// }
// Create the object to do the softmax.
// We won't be using the shared memory for either softmax at all
Softmax softmax(params, smem_, tidx);
Softmax softmax_dp(params, smem_, tidx);
Gmem_softmax_sum gmem_softmax_sum(params.softmax_lse_ptr, params, tidx);
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
int warp = tidx / Cta_tile_p::THREADS_PER_WARP;
int lane = tidx % Cta_tile_p::THREADS_PER_WARP;
int rows[Mma_tile_p::MMAS_N * 4];
for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
rows[ni * 4 + 0] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2;
rows[ni * 4 + 1] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + 1;
rows[ni * 4 + 2] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + 8;
rows[ni * 4 + 3] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + 9;
}
float p_lse[Mma_tile_p::MMAS_N * 4];
gmem_softmax_sum.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N * 4]>(p_lse), rows);
float dp_sum[Mma_tile_p::MMAS_N * 4];
gmem_softmax_d.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N * 4]>(dp_sum), rows);
// int qid = lane / 8;
// int rows_shfl[Mma_tile_p::MMAS_N];
// for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
// rows_shfl[ni] = ni * Cta_tile_p::WARPS_N * 16 + warp * 16 + (lane % 4) * 2 + (qid / 2) * 8 + (qid % 2);
// }
// float p_lse[Mma_tile_p::MMAS_N];
// gmem_softmax_sum.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N]>(p_lse), rows_shfl);
// float dp_sum[Mma_tile_p::MMAS_N];
// gmem_softmax_d.load_row(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_N]>(dp_sum), rows_shfl);
constexpr int STEPS = Cta_tile_p::N / Cta_tile_p::M;
// Load over the entire sequence length.
for( int l = 0; l < STEPS; l++ ) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_p::MMAS_N];
// smem_mmat.store(frag_do, 0);
// smem_mmat.load(frag_dot[0]);
// smem_mmat.transpose(frag_do, frag_dot[0], 0);
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_pt[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_pt);
// Do this part of P^T = (Q * K^T)^T.
gemm_q_k(acc_pt);
// Trigger the load for the next K values.
if( l < STEPS - 1) {
gemm_q_k.smem_q.move_to_next_write_buffer();
gmem_k.move();
gmem_k.load();
}
// Load the mask for that iteration.
mask.load(l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack_noscale(acc_pt);
// Apply the mask.
softmax.apply_mask(mask);
// Scale by log-sum-exp of the softmax
softmax.template apply_exp_col</*max_in_base2=*/true>(p_lse);
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
softmax.pack(frag_p);
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_dv[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dv);
smem_mmat.transpose(frag_do, frag_dot[0], 0);
// Do this part of dV = P^T * dO.
#pragma unroll
for( int ki = 0; ki < Mma_tile_dk::MMAS_K; ++ki ) {
// fmha::gemm(acc_dv, frag_p[ki], frag_dot[ki]);
if (ki + 1 < Mma_tile_dk::MMAS_K) {
// smem_mmat.store(frag_do, ki + 1);
// smem_mmat.load(frag_dot[(ki + 1) % 2]);
smem_mmat.transpose(frag_do, frag_dot[(ki + 1) % 2], ki + 1);
}
fmha::gemm(acc_dv, frag_p[ki], frag_dot[ki % 2]);
}
__syncthreads();
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_dkv::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_dkv.store(acc_dv, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_dkv::STGS_PER_LOOP];
smem_dkv.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_dkv::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_dv.store(out, ii);
}
// Move to the next part of the output.
gmem_dv.move();
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads();
}
fmha::Fragment_accumulator acc_dpt[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dpt);
// Do this part of dP^T = (dO * V^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of V values.
smem_v.load(frag_v[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dpt, frag_v[(ki - 1) & 1], frag_do[ki - 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
// fmha::gemm(acc_dpt, frag_v[(ki - 1) & 1], frag_do[(ki - 1) & 1]);
fmha::gemm(acc_dpt, frag_v[(ki - 1) & 1], frag_do[(ki - 1)]);
}
// Trigger the load for the next V values.
if( l < STEPS - 1) {
smem_v.move_to_next_write_buffer();
gmem_v.move();
gmem_v.load();
}
softmax_dp.unpack_noscale(acc_dpt);
// TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
// will be zero.
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N * 4; ni++ ) {
// softmax.elt_[mi][ni] *= (softmax_dp.elt_[mi][ni] - dp_sum[ni]);
softmax_dp.elt_[mi][ni] -= dp_sum[ni];
// const float tmp = __shfl_sync(0xffffffff, dp_sum[ni / 4], (ni % 4) * 8 + threadIdx.x % 8);
// softmax_dp.elt_[mi][ni] -= tmp;
}
}
Frag_p frag_dp[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
softmax_dp.pack(frag_dp);
#pragma unroll
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
frag_p[mi][ni].hmul(frag_dp[mi][ni]);
}
}
// using Frag_p = fmha::Fragment_a<fmha::Row>;
// Frag_p frag_p[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
// softmax.pack(frag_p);
// softmax_dp.pack(frag_p);
// gmem_s.store(frag_p, mask);
// gmem_s.move();
__syncthreads();
// Commit the values for K and V into shared memory.
if(l < STEPS - 1) {
gmem_k.commit(gemm_q_k.smem_q);
gmem_v.commit(smem_v);
}
// Declare the accumulators for the dK gemm.
fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);
// Do this part of the dK gemm (dK = dS^T * Q).
#pragma unroll
for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dk, frag_p[ki - 1], frag_qt[(ki - 1) & 1]);
// fmha::gemm(acc_dk, frag_p[ki - 1], frag_qt[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_dk, frag_p[ki - 1], frag_qt[(ki - 1) & 1]);
// fmha::gemm(acc_dk, frag_p[ki - 1], frag_qt[(ki - 1)]);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_dkv::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_dkv.store(acc_dk, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_dkv::STGS_PER_LOOP];
smem_dkv.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_dkv::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_dk.store(out, ii);
}
// Move to the next part of the output.
gmem_dk.move();
gemm_q_k.reload_k();
smem_qt.load(frag_qt[0], 0);
// Make sure the data is in shared memory.
// __syncthreads();
// Switch to the next read buffers and reload the first fragments for the next iteration.
if(l < STEPS - 1) {
gemm_q_k.smem_q.move_to_next_read_buffer();
gemm_q_k.reload_q();
smem_v.move_to_next_read_buffer();
smem_v.load(frag_v[0], 0);
}
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_fprop_fp16_sm80_loop_kernel(Fused_multihead_attention_fprop_params params) {
fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
}
template<typename Kernel_traits>
void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
const bool configure) {
bool is_causal = launch_params.params.is_causal;
// TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
auto kernel = launch_params.is_dropout
? (is_causal
? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, false>)
: (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, false>))
: (is_causal
? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
: (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));
constexpr int N = Kernel_traits::Cta_tile_p::N;
const int loop_steps = (launch_params.params.s + N - 1) / N;
constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
// Don't need smem_size_softmax_lse if we're not looping
const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
+ (loop_steps > 1 ? smem_size_softmax_lse : 0);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
if (configure) {
using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
constexpr int M = Kernel_traits::Cta_tile_p::M;
size_t STEPS = (launch_params.params.s + M - 1) / M;
constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
launch_params.elts_per_thread = elts_per_head;
return;
}
dim3 grid(launch_params.params.h, launch_params.params.b);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
const bool configure) {
if (launch_params.params.d == 16) {
if( launch_params.params.s == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.s == 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
// TD [2022-05-15]: seqlen 512 gives wrong results right now
// using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u>;
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
} else if (launch_params.params.d == 32) {
if( launch_params.params.s == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.s == 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
} else if (launch_params.params.d == 64) {
if( launch_params.params.s == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.s == 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
} else if (launch_params.params.d == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
// if (launch_params.params.d == 64) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// }
}
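// Editorial sketch (not part of the original file): how a host-side caller might drive the
// two-phase interface above. The exact construction of Launch_params is an assumption here and
// depends on the rest of the extension code.
//
//   Launch_params<Fused_multihead_attention_fprop_params> launch_params(/* ... */);
//   // Pass 1: configure only. This fills launch_params.elts_per_thread (e.g. to size the
//   // Philox/dropout offsets) and returns without launching anything.
//   run_fmha_fp16_sm80(launch_params, /*configure=*/true);
//   // ... set up params.philox_args using launch_params.elts_per_thread ...
//   // Pass 2: launch the kernel on launch_params.stream.
//   run_fmha_fp16_sm80(launch_params, /*configure=*/false);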
/***************************************************************************************************
* Copyright (c) 2022, Tri Dao.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
#include <fmha/utils.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits>
struct Gemm_Q_K_base {
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
using Fragment_q = typename Smem_tile_q::Fragment;
using Fragment_k = typename Smem_tile_k::Fragment;
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;
__device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx)
: smem_q(smem_ptr_q, tidx)
, smem_k(smem_ptr_k, tidx) {
}
__device__ inline void load_q() {
smem_q.load(frag_q[0], 0);
}
__device__ inline void reload_q() {
smem_q.load(frag_q[0], 0);
}
Fragment_q frag_q[2][Mma_tile_p::MMAS_M];
Smem_tile_q smem_q;
Smem_tile_k smem_k;
};
template<typename Kernel_traits, bool K_in_regs>
struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
using Base = Gemm_Q_K_base<Kernel_traits>;
using Smem_tile_o = typename Base::Smem_tile_o;
using Smem_tile_q = typename Base::Smem_tile_q;
using Smem_tile_k = typename Base::Smem_tile_k;
using Fragment_k = typename Base::Fragment_k;
using Mma_tile_p = typename Base::Mma_tile_p;
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
// If V is stored in shared memory, we can't load K using the same shared memory.
static_assert(Kernel_traits::V_IN_REGS);
static constexpr int SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE;
static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE;
static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE);
// Shared memory layout: Q | K / V
//                         | O | SOFTMAX   (O and the softmax buffer reuse the K/V region)
static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
+ std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE,
Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX);
__device__ inline Gemm_Q_K(char * smem_, const int tidx)
: Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
}
__device__ inline void load_k(){
#pragma unroll
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
Base::smem_k.load(frag_k[ki], ki);
}
}
template<typename Acc, int M, int N>
__device__ inline void operator()(Acc (&acc_p)[M][N]){
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
Base::smem_q.load(Base::frag_q[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
}
__device__ inline void reload_k(){
// Noop.
}
Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
};
template<typename Kernel_traits>
struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
using Base = Gemm_Q_K_base<Kernel_traits>;
using Smem_tile_o = typename Base::Smem_tile_o;
using Smem_tile_q = typename Base::Smem_tile_q;
using Smem_tile_k = typename Base::Smem_tile_k;
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
using Fragment_k = typename Base::Fragment_k;
using Mma_tile_p = typename Base::Mma_tile_p;
Fragment_k frag_k[2][Mma_tile_p::MMAS_N];
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS;
static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V);
static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE);
static_assert(Smem_tile_v::BYTES_PER_TILE == (int) Smem_tile_k::BYTES_PER_TILE);
static constexpr int SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE;
static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE;
// If V_IN_REGS and SHARE_SMEM_FOR_K_AND_V: Q | K/V | O | SOFTMAX
// If !V_IN_REGS (then !SHARE_SMEM_FOR_K_AND_V): Q | K | V | O | SOFTMAX
static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
+ (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE
+ Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX;
__device__ inline Gemm_Q_K(char * smem_, const int tidx)
: Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
}
__device__ inline void load_k(){
Base::smem_k.load(frag_k[0], 0);
}
template<typename Acc, int M, int N>
__device__ inline void operator()(Acc (&acc_p)[M][N]){
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
Base::smem_q.load(Base::frag_q[ki & 1], ki);
Base::smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
}
__device__ inline void reload_k(){
Base::smem_k.load(frag_k[0], 0);
}
};
template<typename Kernel_traits>
constexpr size_t get_dynamic_smem_size(){
return Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>::SMEM_BYTES;
}
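// Editorial sketch (mirrors run_fmha_fp16_sm80_loop_; `My_kernel_traits` and `my_kernel` are
// placeholder names): a launcher uses get_dynamic_smem_size to opt in to more than 48 KB of
// dynamic shared memory before launching. The real launcher additionally adds
// Smem_dp_sum::BYTES_PER_TILE when it loops over K/V blocks.
//
//   constexpr int smem_bytes = fmha::get_dynamic_smem_size<My_kernel_traits>();
//   if( smem_bytes >= 48 * 1024 ) {
//       FMHA_CHECK_CUDA(cudaFuncSetAttribute(
//           my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
//   }
//   my_kernel<<<grid, My_kernel_traits::THREADS, smem_bytes, stream>>>(params);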
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
using Gmem_tile_o_tmp = fmha::Gmem_tile_o<Cta_tile_o, 4>;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum;
using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
// if( binfo.stop_early() ) return;
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
Gemm1 gemm_q_k(smem_, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
// Wind gmem tiles to the correct position.
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
const int begin_og = begin;
begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;
const int steps_og = steps;
steps -= begin - begin_og;
gmem_q.move(begin);
gmem_o.move(begin);
gmem_o_tmp.move(begin);
if (Return_softmax) { gmem_s.move(begin); }
gmem_softmax_lse.move(begin);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("begin = %d, steps = %d\n", begin, steps);
// }
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params, 2, binfo, tidx);
// The base pointer of smem_v.
char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
Smem_tile_v smem_v(smem_v_, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);
if (!Is_first) {
gmem_k.move(loop_step_idx);
gmem_v.move(loop_step_idx);
if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); }
}
// Trigger the loads for K.
gmem_k.load();
// Trigger the loads for Q.
gmem_q.load();
// Trigger the loads for V.
gmem_v.load();
if (!Is_first) { __syncthreads(); }
float p_prev_lse[Mma_tile_p::MMAS_M * 2];
if (!Is_first) {
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
}
// Commit the data for Q and V to shared memory.
gmem_q.commit(gemm_q_k.smem_q);
gmem_v.commit(smem_v);
// const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
// #pragma unroll
// for(int it=0;it < Gmem_tile_k::LDGS;it++){
// gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
// }
// Commit the data for K to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
gmem_k.commit(gemm_q_k.smem_k);
}
__syncthreads();
// Load the fragments for Q.
gemm_q_k.load_q();
// Load the fragments for V. We keep the data in registers during the entire kernel.
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
smem_v.load(frag_v[ki], ki);
}
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for V.
gmem_k.commit(gemm_q_k.smem_k);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for K.
gemm_q_k.load_k();
// Create the object to do the softmax.
Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx);
Smem_softmax_sum smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]), tidx);
// Load over the entire sequence length.
for( int l = 0; l < steps; l++ ) {
if((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen) break;
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P = Q * K^T.
gemm_q_k(acc_p);
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
if (!Is_first) { gmem_o_tmp.load(out, 0); }
// Trigger the load for the next Q values.
if( l < steps - 1) {
gemm_q_k.smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load();
}
// Load the mask for that iteration.
mask.load(begin + l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack_noscale(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
// softmax.unpack_noscale_half_and_apply_mask(acc_p, mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads();
}
// if (!Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]);
// }
// }
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
if (!Is_first) {
smem_softmax_lse.store_pair(p_prev_lse, l % 2);
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; }
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; }
}
// Trigger the load for the next LSE values.
if( l < steps - 1) {
if (!Is_first) {
gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
}
}
// __half2 p_max[Mma_tile_p::MMAS_M];
softmax.template reduce_max</*zero_init=*/Is_first>(p_max);
// if ((threadIdx.x == 0) && (l == 38)) {
// printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]);
// }
// if (!Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
// }
// }
// Compute the exponential value.
// softmax.apply_exp(p_max);
softmax.scale_apply_exp(p_max, params.scale_bmm1f);
// if (!Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
// }
// }
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
// if (!Is_first) {
// int warp = tidx / Cta_tile_p::THREADS_PER_WARP;
// int lane = tidx % Cta_tile_p::THREADS_PER_WARP;
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
// p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0;
// }
// }
// softmax.reduce_sum(p_sum);
softmax.reduce_sum_before_sync_(p_sum);
// softmax.template reduce_sum_before_sync_</*zero_init=*/Is_first>(p_sum);
// float p_sum_log[Mma_tile_p::MMAS_M * 2];
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) {
// float sum = p_sum[mi];
// // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + __logf(sum);
// constexpr float kLog2e = M_LOG2E;
// p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum);
// }
// // gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum));
// gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum_log));
// gmem_softmax_lse.move();
// // Finalize softmax on the accumulators of P^T.
// softmax.scale(p_sum);
constexpr bool encode_dropout_in_sign_bit = Return_softmax;
if (Is_dropout) {
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, params.p_dropout_in_uint);
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint);
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
}
using Frag_p = fmha::Fragment_a<fmha::Row>;
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
softmax.pack(frag_p);
if (Return_softmax) {
gmem_s.store(frag_p, mask);
gmem_s.move();
}
// Commit the values for Q into shared memory.
if(l < steps - 1) {
gmem_q.commit(gemm_q_k.smem_q);
}
if (Is_dropout && encode_dropout_in_sign_bit) {
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
frag_p[ki][mi].hrelu_();
}
}
}
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
}
// The mapping from tidx to rows changes between the softmax and the O-reduction.
// So we recalculate the max.
float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
// TODO: not sure if this is right for seqlen 128 or 256
int rows[Gmem_tile_o::STGS_PER_LOOP];
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
}
softmax.reduce_max_after_sync_(p_max_o, rows);
static_assert(Mma_tile_o::MMAS_M == 1);
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
p_max_o[jj][0] *= params.scale_bmm1f;
}
float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
if (!Is_first) { smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); }
// if (!Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]);
// }
// }
static_assert(Gmem_tile_o::LOOPS == 1);
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, 0);
// Make sure the data is in shared memory.
__syncthreads();
static_assert(Mma_tile_o::MMAS_M == 1);
float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
softmax.reduce_sum_after_sync_(p_sum_o, rows);
if (!Is_first) {
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
p_sum_o[jj][0] += p_prev_scale_o[jj];
}
}
float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
#pragma unroll
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
float sum = p_sum_o[jj][0];
p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum);
// if (sum == 0.f || sum != sum) {
// printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]);
// }
// if (Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
// }
// }
if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS)) {
gmem_softmax_lse.store_row(
reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
}
}
gmem_softmax_lse.move();
// Load from shared memory.
if (!Is_first) {
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]);
}
}
smem_o.template load</*zero_init=*/Is_first>(out);
const bool is_final_write =
Is_last
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
|| ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
#pragma unroll
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
float sum = p_sum_o[jj][0];
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
if (Is_dropout && is_final_write) {
inv_sum *= params.rp_dropout;
}
out[jj] = fmha::fmul4(out[jj], inv_sum);
}
// if (Is_dropout && Is_last) {
// for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
// out[jj] = fmha::fmul4(out[jj], params.rp_dropout);
// }
// }
// Output the values.
if (is_final_write) {
gmem_o.store(out, 0);
gmem_o.move();
} else {
gmem_o_tmp.store(out, 0);
}
// Move to the next part of the output.
if (!(Is_first && Is_last)) { gmem_o_tmp.move(); }
gemm_q_k.reload_k();
// Make sure we are reading from the correct buffer.
gemm_q_k.smem_q.move_to_next_read_buffer();
// Trigger the load from shared memory for the next series of Q values.
if(l < steps - 1) {
gemm_q_k.reload_q();
}
} // Outer loop over the sequence length.
}
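// Note (editorial recap of the rescaling above; the softmax scale factor and dropout handling are
// in the code, omitted here for brevity): with lse_prev the log-sum-exp carried over from previous
// K/V blocks and S the scores of the current block, per row of the Q tile
//     m_new      = max(lse_prev, rowmax(S))
//     p          = exp(S - m_new),        s_block = rowsum(p)
//     scale_prev = exp(lse_prev - m_new)                  // p_prev_scale_o
//     s_new      = s_block + scale_prev                   // exp(lse_prev - m_new) = s_prev * exp(m_prev - m_new)
//     O_new      = (scale_prev * O_prev + p * V) / s_new  // O_prev read back from o_tmp
//     lse_new    = m_new + log(s_new)                     // written back through gmem_softmax_lse
// Intermediate O goes to o_tmp; only the final K/V block for a row (is_final_write) writes to
// gmem_o, with the dropout rescaling folded in at that point.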
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, typename Params>
inline __device__ void device_1xN_loop(const Params &params) {
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
const int STEPS = params.s / Kernel_traits::Cta_tile_p::M;
constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
if (params.s == N_per_loop) {
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
} else {
const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx);
}
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1);
}
}
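// Editorial example of the dispatch above: with params.s = 512 and Cta_tile_p::N = 256,
// max_loop_steps = 2 and the calls are
//     device_1xN_<..., Is_first=true,  Is_last=false>(..., /*loop_step_idx=*/0);  // K/V cols [0, 256)
//     device_1xN_<..., Is_first=false, Is_last=true >(..., /*loop_step_idx=*/1);  // K/V cols [256, 512)
// With params.s == N_per_loop a single call with Is_first = Is_last = true handles everything, so
// the o_tmp / previous-LSE rescaling path is never taken.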
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <philox.cuh>
#include <fmha.h>
#include <fmha/utils.h>
#include <fmha/smem_tile.h>
#include <fmha/gmem_tile.h>
#include <fmha/mask.h>
#include <fmha/softmax.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS_PER_CTA>
struct BlockInfoPadded {
template<typename Params>
__device__ BlockInfoPadded(const Params &params,
const int bidb,
const int bidh,
const int tidx)
: bidb(bidb), bidh(bidh), h(params.h) {
// The block index.
sum_s = params.cu_seqlens[bidb];
actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s;
bidx = sum_s * params.h + bidh;
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
}
__device__ bool stop_early(const int start_col = 0) const {
return actual_seqlen <= start_col;
}
int actual_seqlen;
int bidx;
int sum_s;
int bidh;
int bidb;
int tidx_global;
int h;
};
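// Editorial example: cu_seqlens holds the cumulative sequence lengths of the packed batch, so for
// three sequences of lengths 3, 5 and 2,
//     params.cu_seqlens = {0, 3, 8, 10}
// and for bidb = 1: sum_s = 3, actual_seqlen = 8 - 3 = 5, and bidx = 3 * params.h + bidh
// (presumably the base offset of this sequence/head pair in the packed, token-major layout).
// stop_early(start_col) is true when the K/V column block starting at start_col lies at or beyond
// actual_seqlen.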
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int CHUNKS, typename Cta_tile>
struct Noloop_traits{
// Interpretation of Cta_tile dims, i.e. Cta_tile_p:
enum{ STEP = Cta_tile::M };
enum{ SEQLEN = Cta_tile::N };
template<typename Block_info>
inline __device__ Noloop_traits(const int bidc, const Block_info& binfo)
: bidc_(bidc) {
const int seqlen = binfo.actual_seqlen;
const int steps = (seqlen + STEP - 1) / STEP;
const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;
const int step_begin = bidc_ * steps_per_chunk;
const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
const int actual_steps = max(0, step_end - step_begin);
loop_offset_ = step_begin;
num_steps_ = actual_steps;
}
template<typename ... Tiles>
inline __device__ void move_all(Tiles & ... tiles) const {
using expand_type = int[];
for( int s = 0; s < loop_offset_; s++ ) {
expand_type{ (tiles.move(), 0)... };
}
}
inline __device__ int get_idx_dk() const {
//return bidc_;
return bidc_ * 2 + 0;
}
inline __device__ int get_idx_dv() const {
//return CHUNKS + bidc_;
return bidc_ * 2 + 1;
}
inline __device__ int offset_loop_count(const int l) {
// convert loop counter to position in the outer sequence
return (loop_offset_ + l) * STEP;
}
const uint32_t bidc_;
int loop_offset_;
int num_steps_;
};
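// Editorial example: with STEP = Cta_tile::M = 16 and actual_seqlen = 384, steps = 24. For
// CHUNKS = 2, steps_per_chunk = 12, so the CTA with bidc_ = 0 handles steps [0, 12) and the CTA
// with bidc_ = 1 handles steps [12, 24); move_all() winds the gmem tiles forward by loop_offset_
// steps before the main loop, and get_idx_dk() / get_idx_dv() give each chunk its own slots
// (2 * bidc_ and 2 * bidc_ + 1), presumably for the per-chunk partial dK / dV results.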
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits>
std::tuple<int , int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {
constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M;
const int num_full_heads = heads_total / total_ctas;
const int heads_last_wave = heads_total % total_ctas;
int num_main_groups = 0;
int main_steps = 0;
int rest_steps = 0;
if( heads_last_wave > 0 ) {
// Number of CTA groups that process within heads.
num_main_groups = total_ctas / heads_last_wave;
// Remaining CTAs that process between heads.
const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
if(rest_ctas == 0) {
// We have exactly "num_main_groups" CTAs to process each of the remaining heads.
main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
num_main_groups = STEPS_PER_HEAD / main_steps; // Here: main_steps > 0
rest_steps = STEPS_PER_HEAD % main_steps;
} else {
// Ideal number of steps if we could load-balance as evenly as possible.
const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
// Iterations that a "rest" CTA has to do at most.
const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
// Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs.
main_steps = steps_ideal;
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
const int max_rest_total_steps = rest_steps * max_rest_iters;
if( max_rest_total_steps < main_steps )
break;
}
rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
}
}
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
const int elts_per_thread = max_steps * elts_per_thread_per_step;
return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread};
}
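// Worked example (editorial): with STEPS_PER_HEAD = 16, total_ctas = 8 and heads_total = 20,
// num_full_heads = 2 (each CTA first processes two whole heads) and heads_last_wave = 4.
// num_main_groups = 8 / 4 = 2 and rest_ctas = 8 - 4 * 2 = 0, so the rest_ctas == 0 branch applies:
// main_steps = (16 + 2 - 1) / 2 = 8, num_main_groups = 16 / 8 = 2 and rest_steps = 16 % 8 = 0,
// i.e. each of the last 4 heads is split across 2 CTAs doing 8 steps apiece. Then
// max_steps = 16 * 2 + max(8, 0) = 40 and elts_per_thread = 40 * MMAS_M * MMAS_N * 8.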
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define FMHA_CHECK_CUDA( call ) \
do { \
cudaError_t status_ = call; \
if( status_ != cudaSuccess ) { \
fprintf( stderr, \
"CUDA error (%s:%d): %s\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString( status_ ) ); \
exit( 1 ); \
} \
} while( 0 )
////////////////////////////////////////////////////////////////////////////////////////////////////
enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {
if( dtype == DATA_TYPE_FP16 ) {
half x = __float2half_rn( norm );
uint16_t h = reinterpret_cast<const uint16_t &>( x );
ushort2 h2 = { h, h };
alpha = reinterpret_cast<const uint32_t &>( h2 );
} else if( dtype == DATA_TYPE_FP32 ) {
alpha = reinterpret_cast<const uint32_t &>( norm );
} else if( dtype == DATA_TYPE_INT32 ) {
int32_t inorm = static_cast<int32_t>( norm );
alpha = reinterpret_cast<const uint32_t &>( inorm );
} else {
assert( false );
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {
switch( dtype ) {
case DATA_TYPE_FP32:
return n * 4;
case DATA_TYPE_FP16:
return n * 2;
case DATA_TYPE_INT32:
return n * 4;
case DATA_TYPE_INT8:
return n;
default:
assert( false );
return 0;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
#pragma once
// Philox CUDA.
namespace {
class Philox {
public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset)
: STATE(0)
, key(reinterpret_cast<const uint2&>(seed)) {
//key.x = (unsigned int)seed;
//key.y = (unsigned int)(seed >> 32);
//counter = make_uint4(0, 0, 0, 0);
//counter.z = (unsigned int)(subsequence);
//counter.w = (unsigned int)(subsequence >> 32);
//STATE = 0;
//incr_n(offset / 4);
// key = reinterpret_cast<const uint2&>(seed);
ull2 * tmp = reinterpret_cast<ull2*>(&counter);
tmp->x = offset / 4;
tmp->y = subsequence;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
// }
}
__device__ inline uint4 operator()() {
// if (STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
// 7-round philox
#pragma unroll
for (int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A);
key_.y += (kPhilox10B);
}
// output = single_round(counter_, key_);
uint4 output = single_round(counter_, key_);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
// }
incr();
// }
// return a float4 directly
// unsigned long ret;
// switch(STATE) {
// case 0: ret = output.x; break;
// case 1: ret = output.y; break;
// case 2: ret = output.z; break;
// case 3: ret = output.w; break;
//}
// STATE = (STATE + 1) % 4;
return output;
}
private:
struct ull2 {
uint64_t x;
uint64_t y;
};
uint4 counter;
// uint4 output;
const uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ uint4 incr128 (uint4 ctr)
{
uint4 res;
asm ("add.cc.u32 %0, %4, %8;\n\t"
"addc.cc.u32 %1, %5, %9;\n\t"
"addc.cc.u32 %2, %6, %10;\n\t"
"addc.u32 %3, %7, %11;\n\t"
: "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
: "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
"n"(1), "n"(0), "n"(0), "n"(0));
return res;
}
__device__ inline void incr() {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// }
counter = incr128(counter);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// }
}
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a * b;
}
__device__ uint2 mulhilo32_v2 (const unsigned int a, const unsigned int b)
{
uint2 *res;
unsigned long long tmp;
asm ("mul.wide.u32 %0, %1, %2;\n\t"
: "=l"(tmp)
: "r"(a), "r"(b));
res = (uint2*)(&tmp);
return *res;
}
__device__ inline uint4 single_round(const uint4 ctr, const uint2 key) {
//unsigned int hi0;
//unsigned int hi1;
//unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
//unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
//uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x);
uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z);
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
return ret;
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
static const unsigned long kPhiloxSA = 0xD2511F53;
static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
// Inverse of 2^32.
constexpr float M_RAN_INVM32 = 2.3283064e-10f;
__device__ __inline__ float4 uniform4(const uint4 x) {
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
x.w * M_RAN_INVM32);
}
} // namespace
# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
# We split the input differently ((d 2) -> d 2 instead of (2 d) -> d 2), following the original
# paper's implementation. This should not matter.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
# NOTE: Almost the same right now, moving parts to Triton is the next step
from typing import Tuple
import math
import torch
from einops import rearrange, repeat
def rotate_half(x):
# rearrange doesn't work with torch.jit
# x = rearrange(x, '... (d r) -> ... d r', r=2)
x = x.unflatten(dim=-1, sizes=(-1, 2))
x1, x2 = x.unbind(dim=-1)
rotated_x = torch.stack((-x2, x1), dim=-1)
# return rearrange(rotated_x, '... d r -> ... (d r)')
return rotated_x.flatten(start_dim=-2)
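# Editorial example: with the (d 2) split used here, pairs are adjacent elements of the last
# dimension, so rotate_half([x1, x2, x3, x4]) = [-x2, x1, -x4, x3]. Concretely:
#   rotate_half(torch.tensor([1., 2., 3., 4.]))  ->  tensor([-2., 1., -4., 3.])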
@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int = -2):
# NOTE: This could probably be moved to Triton
# Handle a possible sequence length mismatch in between q and k
cos = cos[:x.shape[seq_dimension], :]
sin = sin[:x.shape[seq_dimension], :]
if seq_dimension == -3:
cos = cos[:, None, :]
sin = sin[:, None, :]
return (x * cos) + (rotate_half(x) * sin)
class RotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et. al).
A crucial insight from the method is that the query and keys are
transformed by rotation matrices which depend only on their absolute
positions, so that the attention scores depend on the relative positions.
Other implementations are available in the Rotary Transformer repo_ and in
GPT-NeoX_, GPT-NeoX was an inspiration
.. _RoFormer: https://arxiv.org/abs/2104.09864
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
.. warning: Please note that this embedding is not registered on purpose, as it is transformative
(it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
"""
def __init__(self, dim_model: int, *_, **__):
super().__init__()
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = None
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_tables(self, x, seq_dimension=-2):
seq_len = x.shape[seq_dimension]
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (seq_len != self._seq_len_cached or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype):
self._seq_len_cached = seq_len
t = torch.arange(x.shape[seq_dimension], device=x.device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq)
self._cos_cached = repeat(torch.cos(freqs).to(x.dtype), '... d -> ... (d 2)')
self._sin_cached = repeat(torch.sin(freqs).to(x.dtype), '... d -> ... (d 2)')
return self._cos_cached, self._sin_cached
def forward(self, q: torch.Tensor, k: torch.Tensor,
seq_dimension=-2) -> Tuple[torch.Tensor, torch.Tensor]:
assert seq_dimension in [-2, -3] # Either (bs, h, s, d) or (bs, s, h, d)
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
k, seq_dimension=seq_dimension
)
return (
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dimension),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dimension),
)
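# Editorial usage sketch: q and k can be laid out either as (bs, h, s, d) with the default
# seq_dimension=-2 or as (bs, s, h, d) with seq_dimension=-3.
#
#   rotary = RotaryEmbedding(dim_model=64)
#   q = torch.randn(2, 8, 128, 64)   # (bs, h, s, d)
#   k = torch.randn(2, 8, 128, 64)
#   q_rot, k_rot = rotary(q, k)      # same shapes, rotated according to position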
class RotaryEmbedding2D(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
assert dim % 4 == 0
self.rotary_emb1d = RotaryEmbedding(dim // 2)
def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dimension=-2):
assert seq_dimension in [-2, -3] # Either (bs, h, s, d) or (bs, s, h, d)
seqlen = q.shape[seq_dimension]
seqlen_sqrt = int(math.sqrt(seqlen))
assert seqlen == seqlen_sqrt ** 2
if seq_dimension == -3: # (bs, s, h, d)
q = rearrange(q, 'b s h d -> b h s d')
k = rearrange(k, 'b s h d -> b h s d')
q0, q1 = q.chunk(2, dim=-1)
k0, k1 = k.chunk(2, dim=-1)
# (bs, h, s, d)
q0 = rearrange(q0, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
k0 = rearrange(k0, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
q0_emb, k0_emb = self.rotary_emb1d(q0, k0, seq_dimension=-2)
q0_emb = rearrange(q0_emb, 'b nheads h w d -> b nheads (h w) d')
k0_emb = rearrange(k0_emb, 'b nheads h w d -> b nheads (h w) d')
q1 = rearrange(q1, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
k1 = rearrange(k1, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
q1_emb, k1_emb = self.rotary_emb1d(q1, k1, seq_dimension=-3)
q1_emb = rearrange(q1_emb, 'b nheads h w d -> b nheads (h w) d')
k1_emb = rearrange(k1_emb, 'b nheads h w d -> b nheads (h w) d')
q_emb, k_emb = torch.cat([q0_emb, q1_emb], dim=-1), torch.cat([k0_emb, k1_emb], dim=-1)
if seq_dimension == -3:
q_emb = rearrange(q_emb, 'b h s d -> b s h d')
k_emb = rearrange(k_emb, 'b h s d -> b s h d')
return q_emb, k_emb
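# Editorial usage sketch: the 2D variant views the sequence as an h x w grid (so the sequence
# length must be a perfect square) and applies the 1D rotary embedding to the first half of the
# features along one grid axis and to the second half along the other; dim must be divisible by 4.
#
#   rotary2d = RotaryEmbedding2D(dim=64)
#   q = torch.randn(2, 8, 16 * 16, 64)   # (bs, h, s=256, d)
#   k = torch.randn(2, 8, 16 * 16, 64)
#   q_rot, k_rot = rotary2d(q, k)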
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import torch
import torch.nn as nn
import stream_attn_cuda
def _stream_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal, return_softmax):
context, softmax_lse, *rest = stream_attn_cuda.fwd(qkv, cu_seqlens, dropout_p, max_s, softmax_scale,
False, causal, return_softmax, None)
# if context.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
S_dmask = rest[0] if return_softmax else None
return context, softmax_lse, S_dmask
def _stream_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, max_s,
softmax_scale, causal):
dqkv, dp, softmax_d = stream_attn_cuda.bwd(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p,
softmax_scale, max_s, False, causal, None)
# if dqkv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dqkv
class StreamAttnFun(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _stream_attn_forward(
qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=False
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p
ctx.max_s = max_s
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return context
@staticmethod
def backward(ctx, dout):
qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
# S_dmask is None, temporarily use another tensor just to get it running
dqkv = _stream_attn_backward(
dout, qkv, context, context, softmax_lse, cu_seqlens, ctx.dropout_p,
ctx.max_s, ctx.softmax_scale, ctx.causal
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None
# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class StreamAttnFunWithS(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _stream_attn_forward(
qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=True
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p
ctx.max_s = max_s
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return context, S_dmask, softmax_lse
@staticmethod
def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = _stream_attn_backward(
dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, ctx.dropout_p,
ctx.max_s, ctx.softmax_scale, ctx.causal
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None
def stream_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
"""
func = StreamAttnFun if not return_attn_probs else StreamAttnFunWithS
return func.apply(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal)
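# Editorial usage sketch. The exact packed qkv layout is defined by the CUDA extension; the shapes
# below (qkv of shape (total_tokens, 3, nheads, head_dim), cu_seqlens as int32 prefix sums, max_s
# the longest sequence length) are assumptions for illustration only.
#
#   # Two sequences of lengths 3 and 5 packed together:
#   cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device='cuda')
#   qkv = torch.randn(8, 3, nheads, head_dim, device='cuda', dtype=torch.float16)
#   out = stream_attn_func(qkv, cu_seqlens, dropout_p=0.1, max_s=5, causal=True)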
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import torch
import torch.nn as nn
import stream_attn_cuda
def convert_blockmask(blockmask, causal):
"""Convert from the 0-1 format to the format used by the CUDA code.
0 means the block is skipped.
nonzero means the block is not skipped.
Argument:
blockmask: (row, col): a 0-1 tensor
Return:
blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
indices of the nonzero blocks, padded with -1 to reach length @row.
The indices are multiplied by 4, with the smallest bit used to encode whether
it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
the last nonzero in its row.
"""
assert not causal
# TD [2022-05-13]: The indexing and sorting are very tricky
nrow, ncol = blockmask.shape
# Sort does not support bool on CUDA
blockmask = blockmask.to(dtype=torch.uint8)
nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
]
first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
]
nonzero_idx = nonzero_sorted_rowidx * 4
nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
nonzero_idx[nonzero_val == 0] = -1
return nonzero_idx.T.contiguous().to(dtype=torch.int32)
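

# Worked example of the encoding above on a tiny hand-checked mask. The helper name and
# the 2x2 mask are illustrative assumptions.
def _example_convert_blockmask():
    import torch
    # Row 0 keeps only block-column 0; row 1 keeps both block-columns.
    blockmask = torch.tensor([[1, 0],
                              [1, 1]])
    converted = convert_blockmask(blockmask, causal=False)
    # Expected result, shape (ncol=2, nrow=2):
    #   column 0 -> [3, 5]: row 0 is 0*4 + 1 (first in its row) + 2 (last in its row),
    #                       row 1 is 1*4 + 1 (first in its row)
    #   column 1 -> [6, -1]: row 1 is 1*4 + 2 (last in its row), padded with -1
    return converted
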
def _stream_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
causal, return_softmax):
context, softmax_lse, *rest = stream_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
max_s, softmax_scale, causal,
return_softmax, None)
# if context.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
S_dmask = rest[0] if return_softmax else None
return context, softmax_lse, S_dmask
def _stream_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
dropout_p, max_s, softmax_scale, causal):
dqkv, dp, softmax_d = stream_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
blockmask, dropout_p, softmax_scale, max_s,
causal, None)
# if dqkv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dqkv
class StreamBlocksparseAttnFun(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _stream_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
return_softmax=False
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p
ctx.max_s = max_s
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return context
@staticmethod
def backward(ctx, dout):
qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
        # S_dmask is None here (the forward pass used return_softmax=False), so temporarily
        # pass `context` in its place just to get it running.
dqkv = _stream_blocksparse_attn_backward(
dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
ctx.max_s, ctx.softmax_scale, ctx.causal
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
        # One gradient per forward input:
        # (qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal).
        return dqkv, None, None, None, None, None, None
# We duplicate the code below so that both the output and the softmax can be returned, for testing.
# Returning both makes the backward pass a bit slower, so the class above remains the default for speed.
class StreamBlocksparseAttnFunWithS(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _stream_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
return_softmax=True
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p
ctx.max_s = max_s
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return context, S_dmask, softmax_lse
@staticmethod
def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = _stream_blocksparse_attn_backward(
dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
ctx.max_s, ctx.softmax_scale, ctx.causal
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None
def stream_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
causal=False, return_attn_probs=False, convert_mask=True):
"""dropout_p should be set to 0.0 during evaluation
"""
func = StreamBlocksparseAttnFun if not return_attn_probs else StreamBlocksparseAttnFunWithS
if convert_mask:
blockmask = convert_blockmask(blockmask, causal=causal)
return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)
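

# Usage sketch for stream_blocksparse_attn_func with a 0-1 block layout. The layout uses the
# same (seqlen // 16, seqlen // 256) granularity that StreamingBlocksparseAttention uses when
# slicing its layout; sizes and the helper name are illustrative assumptions.
def _example_stream_blocksparse_attn_usage():
    import torch
    batch_size, seqlen, nheads, headdim = 2, 256, 8, 64
    qkv = torch.randn(batch_size * seqlen, 3, nheads, headdim,
                      device='cuda', dtype=torch.float16, requires_grad=True)
    cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                              dtype=torch.int32, device=qkv.device)
    # 0-1 layout over (seqlen // 16, seqlen // 256) blocks; here every block is kept.
    blockmask = torch.ones(seqlen // 16, seqlen // 256, device=qkv.device)
    # convert_mask=True (the default) runs convert_blockmask on this 0-1 layout.
    out = stream_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p=0.0,
                                       max_s=seqlen, causal=False)
    return out  # shape: (batch_size * seqlen, nheads, headdim)
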
import math
import torch
import torch.nn as nn
from einops import rearrange
from rotary import RotaryEmbedding, RotaryEmbedding2D
from stream_attn_interface import stream_attn_func
from bert_padding import unpad_input, pad_input, index_first_axis
class StreamingAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_temp: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
                       (default: 0.0)
"""
def __init__(self, softmax_temp=None, attention_dropout=0.0, device=None, dtype=None):
super().__init__()
self.softmax_temp = softmax_temp
self.dropout_p = attention_dropout
def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
max_s=None, need_weights=False):
"""Implements the multihead softmax attention.
Arguments
---------
            qkv: The tensor containing the query, key, and value.
                (B, S, 3, H, D) if cu_seqlens is None (padded input);
                (nnz, 3, H, D) if the input is already unpadded (cu_seqlens provided).
attn_mask: An implementation of BaseMask that encodes where each
query can attend to
            key_padding_mask: An implementation of BaseMask that encodes how
                many queries each sequence in the batch consists of
"""
assert not need_weights
assert attn_mask is None
assert qkv.dtype == torch.float16
assert qkv.is_cuda
if cu_seqlens is None:
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
if key_padding_mask is None:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
output = stream_attn_func(qkv, cu_seqlens, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
else:
key_padding_mask_bool = key_padding_mask.bool_matrix
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
output_unpad = stream_attn_func(x_unpad, cu_seqlens,
self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal)
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads)
else:
assert max_s is not None
output = stream_attn_func(qkv, cu_seqlens,
self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal)
return output, None
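

# Usage sketch for the key_padding_mask path of StreamingAttention. That branch only reads
# key_padding_mask.bool_matrix (a (batch, seqlen) bool tensor, True for tokens to keep), so a
# minimal stand-in object suffices; the stand-in class, sizes, and helper name are
# illustrative assumptions.
class _BoolMaskSketch:
    def __init__(self, bool_matrix):
        self.bool_matrix = bool_matrix


def _example_streaming_attention_padded_usage():
    attn = StreamingAttention(attention_dropout=0.0)
    batch_size, seqlen, nheads, headdim = 2, 128, 8, 64
    qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                      device='cuda', dtype=torch.float16)
    # Keep the first 100 tokens of every sequence and treat the rest as padding.
    mask = torch.zeros(batch_size, seqlen, dtype=torch.bool, device=qkv.device)
    mask[:, :100] = True
    out, _ = attn(qkv, key_padding_mask=_BoolMaskSketch(mask), causal=False)
    return out  # shape: (batch_size, seqlen, nheads, headdim)
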
class StreamingMHA(nn.Module):
def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None:
assert batch_first
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.embed_dim = embed_dim
self.causal = causal
self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
assert use_rotary_emb in [None, '1d', '2d']
self.use_rotary_emb = use_rotary_emb
if self.use_rotary_emb == '1d':
self.rotary_emb = RotaryEmbedding(self.head_dim)
elif self.use_rotary_emb == '2d':
self.rotary_emb = RotaryEmbedding2D(self.head_dim)
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.inner_attn = StreamingAttention(attention_dropout=attention_dropout, **factory_kwargs)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
need_weights=False):
qkv = self.Wqkv(x)
if self.use_rotary_emb:
query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
h=self.num_heads).unbind(dim=2)
query, key = self.rotary_emb(query, key, seq_dimension=-3)
qkv = torch.stack([query, key, value], dim=2)
else:
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
need_weights=need_weights, causal=self.causal)
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
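

# Usage sketch for StreamingMHA, assuming a CUDA device and fp16 weights. The second and
# third positional arguments are ignored by forward(), so the input is simply passed three
# times; sizes and the helper name are illustrative assumptions.
def _example_streaming_mha_usage():
    mha = StreamingMHA(embed_dim=512, num_heads=8, causal=True,
                       device='cuda', dtype=torch.float16)
    x = torch.randn(2, 128, 512, device='cuda', dtype=torch.float16)
    out, attn_weights = mha(x, x, x, key_padding_mask=None, need_weights=False)
    return out  # shape: (2, 128, 512); attn_weights is None
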
import math
import torch
import torch.nn as nn
from einops import rearrange
import hydra
from stream_blocksparse_attn_interface import stream_blocksparse_attn_func
from stream_blocksparse_attn_interface import convert_blockmask
from bert_padding import unpad_input, pad_input, index_first_axis
class StreamingBlocksparseAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_temp: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
                       (default: 0.0)
"""
def __init__(self, sparsity_config, softmax_temp=None, attention_dropout=0.0,
max_seq_length=2048, device=None, dtype=None):
super().__init__()
self.sparsity_config = hydra.utils.instantiate(sparsity_config)
self.softmax_temp = softmax_temp
self.dropout_p = attention_dropout
# initialize sparse layout and register as buffer
max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256
layout = self.sparsity_config.make_layout(max_seq_length)
self.register_buffer("layout", layout)
blockmask_converted = convert_blockmask(self.layout, causal=False)
self.register_buffer("blockmask_converted", blockmask_converted)
# logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')
def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
max_s=None, need_weights=False, convert_mask=True):
"""Implements the multihead softmax attention.
Arguments
---------
            qkv: The tensor containing the query, key, and value.
                (B, S, 3, H, D) if cu_seqlens is None (padded input).
attn_mask: An implementation of BaseMask that encodes where each
query can attend to
            key_padding_mask: An implementation of BaseMask that encodes how
                many queries each sequence in the batch consists of
"""
assert not need_weights
assert attn_mask is None
assert qkv.dtype == torch.float16
assert qkv.is_cuda
if cu_seqlens is None:
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
# Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            assert seqlen_rounded // 16 <= self.layout.shape[0]
            assert seqlen_rounded // 256 <= self.layout.shape[1]
blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
if key_padding_mask is None:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
output = stream_blocksparse_attn_func(
qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
else:
key_padding_mask_bool = key_padding_mask.bool_matrix
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
output_unpad = stream_blocksparse_attn_func(
x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
)
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads)
else:
assert max_s is not None
seqlen = max_s
# Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            assert seqlen_rounded // 16 <= self.layout.shape[0]
            assert seqlen_rounded // 256 <= self.layout.shape[1]
blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
if convert_mask:
output = stream_blocksparse_attn_func(
qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal
)
else:
output = stream_blocksparse_attn_func(
qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal,
convert_mask=False,
)
return output, None
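

# Note on the sparsity_config contract: whatever hydra.utils.instantiate resolves it into must
# expose make_layout(max_seq_length) returning a 0-1 layout with at least max_seq_length // 16
# rows and max_seq_length // 256 columns, since forward() slices it as
# layout[:seqlen_rounded // 16, :seqlen_rounded // 256]. The helper below only illustrates that
# expected shape with a dense (all-ones) layout; the name and values are assumptions.
def _example_blocksparse_layout_shape(max_seq_length=2048):
    # (128, 8) for the default max_seq_length of 2048.
    return torch.ones(max_seq_length // 16, max_seq_length // 256)
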
class StreamingBlocksparseMHA(nn.Module):
def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
attention_dropout=0.0, causal=False, max_seq_length=2048,
device=None, dtype=None, **kwargs) -> None:
assert batch_first
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.embed_dim = embed_dim
self.causal = causal
self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.inner_attn = StreamingBlocksparseAttention(
sparsity_config, attention_dropout=attention_dropout,
max_seq_length=max_seq_length, **factory_kwargs
)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
need_weights=False):
qkv = self.Wqkv(x)
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
need_weights=need_weights, causal=self.causal)
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights