Commit 6c3a8c65 authored by Tri Dao

Implement cross attention

parent 01947bc9
@@ -32,10 +32,11 @@ Our tentative roadmap:
 3. [Jun 2022] Refactor to use Cutlass.
 4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
 5. [Jun 2022] Support bf16.
-6. ~~[Jul 2022] Support head dimension 128~~[Done].
-7. [Jul 2022] Support SM70 GPUs (V100).
-8. [Aug 2022] Fuse rotary embedding.
-9. [Aug 2022] Support Attention linear bias (e.g. ALiBi).
+6. ~~[Jul 2022] Implement cross-attention~~[Done].
+7. ~~[Jul 2022] Support head dimension 128~~[Done].
+8. [Jul 2022] Support SM70 GPUs (V100).
+9. [Aug 2022] Fuse rotary embedding.
+10. [Aug 2022] Support Attention linear bias (e.g. ALiBi).

 ## Speedup and Memory Savings
@@ -8,7 +8,7 @@ from einops import rearrange, repeat
 from benchmarks.utils import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
 from flash_attn.bert_padding import unpad_input, pad_input
-from flash_attn.flash_attn_interface import flash_attn_func
+from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
 def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
@@ -62,7 +62,9 @@ qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3,
                       h=nheads).detach().requires_grad_()
 qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()
-fn = lambda qkv_unpad: flash_attn_func(qkv_unpad, cu_seqlens, dropout_p, max_seqlen_in_batch, causal=causal)
+fn = lambda qkv_unpad: flash_attn_unpadded_qkvpacked_func(
+    qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal
+)
 benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
 fn = lambda qkv: attention_ref(qkv, attention_mask_bool, dropout_p, causal=causal)
 benchmark_all(fn, qkv, repeats=repeats, desc='PyTorch Standard Attention')
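For reference, a minimal sketch (not part of this diff) of the renamed benchmark call: `flash_attn_unpadded_qkvpacked_func` now takes the max sequence length before the dropout probability, which is the argument-order change above. The sizes below are illustrative assumptions.

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

# Illustrative sizes (assumptions): 2 sequences of length 128, 8 heads, head dim 64.
batch, seqlen, nheads, d = 2, 128, 8, 64
qkv_unpad = torch.randn(batch * seqlen, 3, nheads, d, dtype=torch.float16,
                        device='cuda', requires_grad=True)
# Cumulative sequence lengths, length batch + 1: [0, 128, 256].
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device='cuda')
dropout_p, max_seqlen_in_batch = 0.1, seqlen
out = flash_attn_unpadded_qkvpacked_func(
    qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=True
)
```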
@@ -42,9 +42,8 @@
 constexpr int TOTAL_DIM = 0;
-constexpr int THREE_DIM = 1;
-constexpr int H_DIM = 2;
-constexpr int D_DIM = 3;
+constexpr int H_DIM = 1;
+constexpr int D_DIM = 2;
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -72,10 +71,7 @@ struct Qkv_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-struct Fused_multihead_attention_fprop_params : public Qkv_params {
-    // The dQKV matrices.
-    void * __restrict__ dqkv_ptr;
+struct FMHA_fprop_params : public Qkv_params {
     // The O matrix (output).
     void * __restrict__ o_ptr;
@@ -90,10 +86,7 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
     // the loop;
     void *__restrict__ o_tmp_ptr;
-    // The dO matrix .
-    void * __restrict__ do_ptr;
-    // The pointer to the S matrix, overwritten by the dP matrix (bwd).
+    // The pointer to the S matrix.
     void * __restrict__ s_ptr;
     // The stride between rows of the S matrix.
     // int64_t s_stride_in_bytes;
@@ -102,18 +95,16 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
     // The pointer to the softmax sum.
     void * __restrict__ softmax_lse_ptr;
-    // The pointer to the softmax d sum.
-    void * __restrict__ dsoftmax_sum;
     // The dimensions.
-    int b, s, d;
+    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded;
     // The scaling factors for the kernel.
     float scale_bmm1f;
-    uint32_t scale_bmm1, scale_softmax, scale_bmm2;
+    uint32_t scale_bmm1;
     // array of length b+1 holding starting offset of each sequence.
-    int * __restrict__ cu_seqlens;
+    int * __restrict__ cu_seqlens_q;
+    int * __restrict__ cu_seqlens_k;
     int *__restrict__ blockmask;
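Splitting `cu_seqlens` into `cu_seqlens_q` and `cu_seqlens_k` is what lets the queries and the keys/values come from sequences of different lengths (cross-attention). A small sketch, with made-up lengths, of what these length-(b+1) prefix-sum arrays look like:

```python
import torch

# Made-up per-sequence lengths for a batch of 3 (query and key lengths may differ).
q_lens = torch.tensor([5, 8, 3], dtype=torch.int32)
k_lens = torch.tensor([7, 7, 10], dtype=torch.int32)
zero = torch.zeros(1, dtype=torch.int32)
cu_seqlens_q = torch.cat([zero, q_lens.cumsum(0, dtype=torch.int32)])  # [0, 5, 13, 16]
cu_seqlens_k = torch.cat([zero, k_lens.cumsum(0, dtype=torch.int32)])  # [0, 7, 14, 24]
# Sequence i owns rows cu_seqlens_q[i]:cu_seqlens_q[i+1] of the unpadded Q buffer,
# and rows cu_seqlens_k[i]:cu_seqlens_k[i+1] of the unpadded K/V buffers.
```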
@@ -136,7 +127,33 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_params>
+struct FMHA_dgrad_params : public FMHA_fprop_params {
+    // The dQKV matrices.
+    void *__restrict__ dq_ptr;
+    void *__restrict__ dk_ptr;
+    void *__restrict__ dv_ptr;
+    // The stride between rows of the dQ, dK and dV matrices.
+    // TD [2022-04-16]: We're using 32-bit indexing to save registers.
+    // The code probably won't work for arrays larger than 2GB.
+    uint32_t dq_row_stride_in_elts;
+    uint32_t dk_row_stride_in_elts;
+    uint32_t dv_row_stride_in_elts;
+    uint32_t dq_head_stride_in_elts;
+    uint32_t dk_head_stride_in_elts;
+    uint32_t dv_head_stride_in_elts;
+    // The dO matrix. We assume it is contiguous.
+    void * __restrict__ do_ptr;
+    // The pointer to the softmax d sum.
+    void * __restrict__ dsoftmax_sum;
+};
+////////////////////////////////////////////////////////////////////////////////////////////////////
+template<typename Kernel_params>
 struct Launch_params{
     Launch_params(cudaDeviceProp * props_,
                   cudaStream_t stream_,
@@ -168,10 +185,10 @@ struct Launch_params{
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
-void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
-void run_fmha_block_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
-void run_fmha_block_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
+void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
+void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
+void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
+void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
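The backward parameters now carry separate dQ/dK/dV pointers and strides instead of the packed `dqkv_ptr` they replace, since with different query and key lengths the gradients no longer share a common row count. A rough sketch (not part of the diff) of the corresponding unpadded buffers; shapes are assumptions for illustration:

```python
import torch

# Hypothetical totals: total_q = sum of query lengths, total_k = sum of key/value lengths.
total_q, total_k, nheads, d = 16, 24, 8, 64
dq = torch.empty(total_q, nheads, d, dtype=torch.float16)
dk = torch.empty(total_k, nheads, d, dtype=torch.float16)
dv = torch.empty(total_k, nheads, d, dtype=torch.float16)
# Strides in elements, as FMHA_dgrad_params stores them (32-bit on purpose, see the comment above).
dq_row_stride_in_elts, dq_head_stride_in_elts = dq.stride(0), dq.stride(1)  # nheads * d, d
dk_row_stride_in_elts, dk_head_stride_in_elts = dk.stride(0), dk.stride(1)
dv_row_stride_in_elts, dv_head_stride_in_elts = dv.stride(0), dv.stride(1)
```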
@@ -63,9 +63,9 @@ struct Gmem_tile_qkv {
     // Ctor.
     template< typename BInfo >
     inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
-                                    const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx)
+                                    const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx, bool use_seqlen_q)
         : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
-        , actual_seqlen(binfo.actual_seqlen)
+        , actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k)
         , ptr(reinterpret_cast<char *>(ptr_))
         , tidx_(tidx) {
@@ -80,7 +80,7 @@ struct Gmem_tile_qkv {
         // The row offset in the batched GEMM. For each seq element, we store QKV in that order.
         // int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
-        uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes);
+        uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes);
         // Add the block index.
         // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
         row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
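In other words, the tile's base address now starts from the prefix sum of whichever side the tile belongs to: `sum_s_q` for Q/O-like tiles (`use_seqlen_q == true`) and `sum_s_k` for K/V tiles. A Python mirror of that index arithmetic, purely for illustration:

```python
def qkv_tile_row_offset(sum_s_q, sum_s_k, row, bidh, row_stride_in_bytes,
                        head_stride_in_elts, bytes_per_element, use_seqlen_q):
    """Byte offset of a tile's first row, mirroring the Gmem_tile_qkv ctor above."""
    # Start of this sequence in the unpadded token dimension (Q side or K/V side).
    sum_s = sum_s_q if use_seqlen_q else sum_s_k
    offset = (sum_s + row) * row_stride_in_bytes
    # Advance to the current head.
    offset += bidh * head_stride_in_elts * bytes_per_element
    return offset
```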
@@ -193,7 +193,7 @@ struct Gmem_tile_o {
     inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
                                   const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx)
         : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
-        , actual_seqlen(binfo.actual_seqlen)
+        , actual_seqlen_q(binfo.actual_seqlen_q)
         , ptr_(reinterpret_cast<char *>(ptr))
         , tidx_(tidx) {
@@ -207,7 +207,7 @@ struct Gmem_tile_o {
         // The row offset in the batched GEMM.
         // int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
-        uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes);
+        uint32_t row_offset = (uint32_t)((binfo.sum_s_q + row) * row_stride_in_bytes);
         row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
         // Assemble the final pointer.
         ptr_ += row_offset + col * BYTES_PER_STG;
@@ -224,7 +224,7 @@ struct Gmem_tile_o {
         #pragma unroll
         for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
             int jj = mi * STGS_PER_LOOP + ii;
-            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen ) {
+            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q ) {
                 break;
             }
@@ -252,7 +252,7 @@ struct Gmem_tile_o {
         #pragma unroll
         for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
             int jj = mi * STGS_PER_LOOP + ii;
-            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen ) {
+            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q ) {
                 break;
             }
@@ -266,7 +266,7 @@ struct Gmem_tile_o {
         // row_ += ROWS * steps;
         // ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps;
         ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps;
-        actual_seqlen -= ROWS * steps;
+        actual_seqlen_q -= ROWS * steps;
     }
     // The stride between rows for the QKV matrice.
@@ -277,7 +277,7 @@ struct Gmem_tile_o {
     // Is the thread active for the last STG?
     int is_active_for_last_stg_;
     // The length of the sequence loaded by that memory tile.
-    int actual_seqlen;
+    int actual_seqlen_q;
     const int tidx_;
 };
@@ -319,8 +319,8 @@ struct Gmem_tile_mma_sd {
         uint32_t bidx = bidb * params.h + bidh;
         // The distance between two blocks (in bytes).
-        // const size_t block_stride_bytes = params.s * params.s * BYTES_PER_ELEMENT;
-        const uint32_t block_stride_bytes = params.s * params.s * BYTES_PER_ELEMENT;
+        // const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
+        const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
         // Set store location for each thread at the beginning of the loop
         ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG;
     }
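With unequal sequence lengths, the per-(batch, head) block of the stored S matrix is `seqlen_q x seqlen_k` rather than `s x s`. A small sketch of the resulting offsets, under the same layout assumptions as the code above (illustration only):

```python
def s_block_offset_bytes(bidb, bidh, h, seqlen_q, seqlen_k,
                         bytes_per_element, tidx, bytes_per_stg):
    """Byte offset of a thread's first store into the S buffer."""
    bidx = bidb * h + bidh                                    # flattened (batch, head) index
    block_stride_bytes = seqlen_q * seqlen_k * bytes_per_element
    return bidx * block_stride_bytes + tidx * bytes_per_stg
```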
@@ -468,8 +468,8 @@ struct Gmem_summary_stats {
         int lane = tidx % Cta_tile::THREADS_PER_WARP;
         // The distance between two blocks (in bytes).
-        // size_t block_stride_bytes = params.s * BYTES_PER_ELEMENT;
-        uint32_t block_stride_bytes = params.s * BYTES_PER_ELEMENT;
+        // size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
+        uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
         // Set store location for each thread at the beginning of the loop
         ptr_row_ = ptr_ + bidx * block_stride_bytes;
@@ -35,8 +35,8 @@ struct Mask {
     using Mma_tile = fmha::Hmma_tile<Cta_tile>;
     template<typename BInfo>
-    __device__ Mask(const BInfo &blockInfo, int tidx, const int loop_step_idx_ = 0)
-        : actual_seqlen(blockInfo.actual_seqlen - loop_step_idx_ * Cta_tile::N)
+    __device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0)
+        : actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N)
         , loop_step_idx(loop_step_idx_) {
         const int warp = tidx / Cta_tile::THREADS_PER_WARP;
@@ -60,12 +60,11 @@ struct Mask {
         // const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
         const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
         const int current_row = row_offset + ii * 8;
-        const bool col_valid = current_col < actual_seqlen;
-        // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen;
-        //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;
-        bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
+        const bool col_valid = current_col < actual_seqlen_k;
+        // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k;
+        //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k;
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
-        //     printf("current_col=%d, current_row=%d, actual_seqlen=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen, col_valid, all_valid);
+        //     printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
         // }
         return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
         // return row_valid && col_valid;
@@ -84,7 +83,7 @@ struct Mask {
     int row;
     int col;
     const int loop_step_idx;
-    const int actual_seqlen;
+    const int actual_seqlen_k;
 };
 } // namespace fmha
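The mask now bounds columns by the key length: a score at (row, col) is kept only if the chunk-local column is inside `actual_seqlen_k`, and, for causal attention, if the global key index does not exceed the query row. A Python restatement of the predicate above, for illustration only:

```python
def mask_is_valid(current_row, current_col, loop_step_idx, cta_n,
                  actual_seqlen_k, is_causal):
    # actual_seqlen_k has already been reduced by loop_step_idx * cta_n in the ctor,
    # so current_col (a column within the current K/V chunk) is compared locally.
    col_valid = current_col < actual_seqlen_k
    if not is_causal:
        return col_valid
    # Causal: the global key index must not exceed the query row index.
    return col_valid and (current_col + loop_step_idx * cta_n <= current_row)
```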
@@ -5,12 +5,12 @@
 #include "fmha_block_dgrad_kernel_1xN_loop.h"
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
-__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(Fused_multihead_attention_fprop_params params) {
+__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
     fmha::compute_block_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
 }
 template<typename Kernel_traits>
-void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
+void run_fmha_block_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream) {
     constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
     constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
     constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
@@ -30,12 +30,12 @@ void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_
     auto kernel = is_dropout
         ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
         : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
-    constexpr int N = Kernel_traits::Cta_tile_p::N;
-    if (params.s == N) {
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
+    if (params.seqlen_k == blocksize_c) {
         kernel = is_dropout
             ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
             : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
-    } else if (params.s == N * 2) {
+    } else if (params.seqlen_k == blocksize_c * 2) {
         kernel = is_dropout
             ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
             : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
@@ -50,7 +50,7 @@ void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
-void run_fmha_block_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
+void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
     if (params.d == 16) {
         using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
         run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
@@ -138,9 +138,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
     // Allocate the global memory tile loader for dQ.
-    Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts, binfo, tidx);
     Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
@@ -148,9 +148,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
@@ -160,7 +160,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for dO.
-    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
     // Allocate the shared memory tile loader for dO.
     Smem_tile_do smem_do(&smem_[0], tidx);
     Smem_tile_dot smem_dot(&smem_[0], tidx);
@@ -172,7 +172,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
@@ -181,7 +181,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
     static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
-    const int steps = params.s / Cta_tile_p::M;
+    const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M;
     // Wind gmem tiles to the correct position.
     int block_row_idx_next = mask_val / 4;
@@ -316,7 +316,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
         //     printf("block_row_idx = %d\n", block_row_idx);
         // }
-        if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen) break;
+        if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
         int mask_val_next = l < steps - 1 ? blockmask.mask_val(l + 1) : -1;
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
@@ -629,7 +629,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
         const bool is_final_write =
             Is_last
-            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
+            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
            || ((mask_val & 0x2) != 0)
            || ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
         if (is_final_write) {
@@ -702,7 +702,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     __syncthreads();
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
     smem_dv.load(dv_out);
-    Gmem_tile_dv gmem_dv(params.dqkv_ptr + 2 * params.h * params.d * 2, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false);
     if (!Is_first) {
         gmem_dv.move(loop_step_idx);
     }
@@ -713,7 +713,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
     //     dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
     // }
-    Gmem_tile_dk gmem_dk(params.dqkv_ptr + params.h * params.d * 2, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
     if (!Is_first) {
         gmem_dk.move(loop_step_idx);
     }
@@ -722,11 +722,11 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-// loop_steps = -1 means the number of steps will be params.s / Kernel_traits::Cta_tile_p::N.
+// loop_steps = -1 means the number of steps will be params.seqlen_k / Kernel_traits::Cta_tile_p::N.
 // This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2.
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params>
 inline __device__ void compute_block_dq_dk_dv_1xN(const Params &params) {
-    constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     // The block index for the batch.
     const int bidb = blockIdx.x;
@@ -745,10 +745,10 @@ inline __device__ void compute_block_dq_dk_dv_1xN(const Params &params) {
         compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
         compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1);
     } else {
-        if (params.s == N_per_loop) {
+        if (params.seqlen_k == blocksize_c) {
             compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
         } else {
-            const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
+            const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
             compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
             for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
                 compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx);
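To summarize the loop structure after this change: the kernel steps over the query dimension in tiles of `Cta_tile_p::M` rows (now rounded up, so `seqlen_q` need not be a multiple of M), and the outer loop walks the key/value dimension in chunks of `blocksize_c`, with `loop_steps` specialized to 1 or 2 when `seqlen_k` is exactly one or two chunks. A sketch of the bookkeeping, for illustration only:

```python
import math

def loop_bounds(seqlen_q, seqlen_k, cta_m, blocksize_c, loop_step_idx=0, is_causal=False):
    # Row tiles over the query dimension (rounded up).
    steps = math.ceil(seqlen_q / cta_m)
    # K/V chunks of width blocksize_c; the launcher specializes loop_steps = 1 and 2.
    max_loop_steps = math.ceil(seqlen_k / blocksize_c)
    # With a causal mask, query rows entirely before this K/V chunk can be skipped.
    begin = (loop_step_idx * blocksize_c) // cta_m if is_causal else 0
    return steps - begin, max_loop_steps
```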
@@ -29,12 +29,12 @@
 #include "fmha_block_fprop_kernel_1xN.h"
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
-__global__ void fmha_block_fprop_fp16_sm80_loop_kernel(Fused_multihead_attention_fprop_params params) {
+__global__ void fmha_block_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
     fmha::device_block_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
 }
 template<typename Kernel_traits>
-void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
+void run_fmha_block_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
                                     const bool configure) {
     bool is_causal = launch_params.params.is_causal;
     // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
@@ -46,8 +46,8 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fpro
         ? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
         : (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));
-    constexpr int N = Kernel_traits::Cta_tile_p::N;
-    const int loop_steps = (launch_params.params.s + N - 1) / N;
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
+    const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
     constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
     // Don't need smem_size_softmax_lse if we're not looping
     const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
@@ -60,7 +60,7 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fpro
     if (configure) {
         using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
         constexpr int M = Kernel_traits::Cta_tile_p::M;
-        size_t STEPS = (launch_params.params.s + M - 1) / M;
+        size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
         constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
         constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
         size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
@@ -75,7 +75,7 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fpro
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
-void run_fmha_block_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params,
+void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
                               const bool configure) {
     if (launch_params.params.d == 16) {
         using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
@@ -97,7 +97,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
     Gemm1 gemm_q_k(smem_, tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
     // Allocate the global memory tile loader for O.
     Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
@@ -122,9 +122,9 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
@@ -206,7 +206,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
         //     printf("block_row_idx = %d\n", block_row_idx);
         // }
-        if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen) break;
+        if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
         int mask_val_next = l < steps - 1 ? blockmask.mask_val(l + 1) : -1;
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
@@ -443,7 +443,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
         const bool is_final_write =
             Is_last
-            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
+            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
             || ((mask_val & 0x2) != 0)
             || ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
@@ -507,13 +507,14 @@ inline __device__ void device_block_1xN_loop(const Params &params) {
     auto seeds = at::cuda::philox::unpack(params.philox_args);
     Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
     Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
-    const int STEPS = params.s / Kernel_traits::Cta_tile_p::M;
-    constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
-    if (params.s == N_per_loop) {
+    constexpr int M = Kernel_traits::Cta_tile_p::M;
+    const int STEPS = (params.seqlen_q + M - 1) / M;
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
+    if (params.seqlen_k == blocksize_c) {
         fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, ph0, ph1, 0);
     } else {
-        const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
+        const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
         fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, ph0, ph1, 0);
         for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
             fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, ph0, ph1, loop_step_idx);
@@ -42,7 +42,7 @@ struct Blockmask {
     template<typename Params>
     __device__ Blockmask(const Params &params, int loop_step_idx) :
-        blockmask_ptr(params.blockmask + loop_step_idx * params.s / 16) {
+        blockmask_ptr(params.blockmask + loop_step_idx * params.seqlen_q / 16) {
     }
     __device__ int mask_val(int block_row_idx) const {
@@ -5,12 +5,12 @@
 #include "fmha_dgrad_kernel_1xN_loop.h"
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
-__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(Fused_multihead_attention_fprop_params params) {
+__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
     fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
 }
 template<typename Kernel_traits>
-void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
+void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream) {
     constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
     constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
     constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
@@ -28,18 +28,18 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
     auto kernel = is_dropout
         ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
         : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
-    constexpr int N = Kernel_traits::Cta_tile_p::N;
-    if (params.s == N) {
+    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
+    if (params.seqlen_k == blocksize_c) {
         kernel = is_dropout
             ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
             : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
-    } else if (params.s == N * 2) {
+    } else if (params.seqlen_k == blocksize_c * 2) {
         kernel = is_dropout
             ? (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
             : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
     }
-    // printf("N = %d, WARPS_N = %d, Smem size = %d\n", N, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
+    // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
     if( smem_size_dq_dk_dv >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -49,12 +49,12 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }
-void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
+void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
     if (params.d == 16) {
-        if( params.s == 128 ) {
+        if( params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.s == 256 ) {
+        } else if( params.seqlen_k == 256 ) {
             using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         } else {
@@ -64,18 +64,18 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &para
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         }
     } else if (params.d == 32) {
-        if( params.s == 128 ) {
+        if( params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.s >= 256 ) {
+        } else if( params.seqlen_k >= 256 ) {
             using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         }
     } else if (params.d == 64) {
-        if( params.s == 128 ) {
+        if( params.seqlen_k == 128 ) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        } else if( params.s >= 256 ) {
+        } else if( params.seqlen_k >= 256 ) {
             auto dprops = at::cuda::getCurrentDeviceProperties();
             if (dprops->major == 8 && dprops->minor == 0) {
                 // Don't share smem for K & V, and don't keep V in registers
@@ -102,10 +102,10 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &para
     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
     // run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
     // } else {
-    //     if( params.s == 128 ) {
+    //     if( params.seqlen_k == 128 ) {
     //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
    //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-    //     } else if( params.s >= 256 ) {
+    //     } else if( params.seqlen_k >= 256 ) {
     //         if (dprops->major == 8 && dprops->minor == 0) {
     //             // Don't share smem for K & V, and don't keep V in registers
     //             // This speeds things up by 2-3% by avoiding register spills, but it
@@ -131,9 +131,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
     // Allocate the global memory tile loader for dQ.
-    Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts, binfo, tidx);
     Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
@@ -141,9 +141,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
@@ -153,7 +153,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for dO.
-    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
     // Allocate the shared memory tile loader for dO.
     Smem_tile_do smem_do(&smem_[0], tidx);
     Smem_tile_dot smem_dot(&smem_[0], tidx);
@@ -165,7 +165,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
@@ -175,8 +175,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
     const int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
-    // constexpr int steps = Cta_tile_p::N / Cta_tile_p::M;
-    const int steps = params.s / Cta_tile_p::M - begin;
+    const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;
     // Wind gmem tiles to the correct position.
     gmem_q.move(begin);
@@ -294,7 +293,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // Load over the entire sequence length.
     for( int l = 0; l < steps; l++ ) {
         const int loop = (begin + l) * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen ) if( loop >= binfo.actual_seqlen_q )
break; break;
// Load the fragments for V. // Load the fragments for V.
...@@ -584,7 +583,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng ...@@ -584,7 +583,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
const bool is_final_write = const bool is_final_write =
Is_last Is_last
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen) || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
|| ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
if (is_final_write) { if (is_final_write) {
// if (Is_dropout) { // if (Is_dropout) {
...@@ -656,7 +655,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng ...@@ -656,7 +655,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
__syncthreads(); __syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS]; uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out); smem_dv.load(dv_out);
Gmem_tile_dv gmem_dv(params.dqkv_ptr + 2 * params.h * params.d * 2, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx); Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false);
if (!Is_first) { if (!Is_first) {
gmem_dv.move(loop_step_idx); gmem_dv.move(loop_step_idx);
} }
...@@ -667,7 +666,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng ...@@ -667,7 +666,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) { // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
// dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f); // dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
// } // }
Gmem_tile_dk gmem_dk(params.dqkv_ptr + params.h * params.d * 2, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx); Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
if (!Is_first) { if (!Is_first) {
gmem_dk.move(loop_step_idx); gmem_dk.move(loop_step_idx);
} }
...@@ -676,11 +675,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng ...@@ -676,11 +675,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// loop_steps = -1 means the number of steps will be params.s / Kernel_traits::Cta_tile_p::N. // loop_steps = -1 means the number of steps will be params.seqlen_k / Kernel_traits::Cta_tile_p::N.
// This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2. // This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params> template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params>
inline __device__ void compute_dq_dk_dv_1xN(const Params &params) { inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N; constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
// The block index for the batch. // The block index for the batch.
const int bidb = blockIdx.x; const int bidb = blockIdx.x;
...@@ -699,10 +698,10 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) { ...@@ -699,10 +698,10 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0); compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1); compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1);
} else { } else {
if (params.s == N_per_loop) { if (params.seqlen_k == blocksize_c) {
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0); compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
} else { } else {
const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop; const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0); compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx); compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx);
......
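The dispatch above walks the backward pass over key blocks of Cta_tile_p::N (blocksize_c) columns. As a reading aid only (not part of the commit), here is a small Python sketch of the (loop_step_idx, Is_first, Is_last) schedule the branches imply; the final Is_last call sits in the part of the hunk that is collapsed above.

# Sketch of the key-block schedule used by compute_dq_dk_dv_1xN above.
def schedule_loop_steps(seqlen_k: int, blocksize_c: int):
    # mirrors: const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
    max_loop_steps = (seqlen_k + blocksize_c - 1) // blocksize_c
    if max_loop_steps <= 1:
        return [(0, True, True)]                  # single key block: Is_first and Is_last at once
    steps = [(0, True, False)]                    # first key block
    steps += [(i, False, False) for i in range(1, max_loop_steps - 1)]   # middle key blocks
    steps += [(max_loop_steps - 1, False, True)]  # last key block (call collapsed in the hunk above)
    return steps

# e.g. seqlen_k = 640, blocksize_c = 256 -> key blocks [0, 256), [256, 512), [512, 640)
print(schedule_loop_steps(640, 256))   # [(0, True, False), (1, False, False), (2, False, True)]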
...@@ -29,12 +29,12 @@ ...@@ -29,12 +29,12 @@
#include "fmha_fprop_kernel_1xN.h" #include "fmha_fprop_kernel_1xN.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax> template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_fprop_fp16_sm80_loop_kernel(Fused_multihead_attention_fprop_params params) { __global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params); fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
} }
template<typename Kernel_traits> template<typename Kernel_traits>
void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
const bool configure) { const bool configure) {
bool is_causal = launch_params.params.is_causal; bool is_causal = launch_params.params.is_causal;
// TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way? // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
...@@ -46,8 +46,8 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_para ...@@ -46,8 +46,8 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_para
? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>) ? (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
: (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>)); : (launch_params.return_softmax ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));
constexpr int N = Kernel_traits::Cta_tile_p::N; constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
const int loop_steps = (launch_params.params.s + N - 1) / N; const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
// Don't need smem_size_softmax_lse if we're not looping // Don't need smem_size_softmax_lse if we're not looping
const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>() const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
...@@ -60,7 +60,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_para ...@@ -60,7 +60,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_para
if (configure) { if (configure) {
using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>; using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
constexpr int M = Kernel_traits::Cta_tile_p::M; constexpr int M = Kernel_traits::Cta_tile_p::M;
size_t STEPS = (launch_params.params.s + M - 1) / M; size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps; size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
...@@ -75,13 +75,13 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_para ...@@ -75,13 +75,13 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_para
FMHA_CHECK_CUDA(cudaPeekAtLastError()); FMHA_CHECK_CUDA(cudaPeekAtLastError());
} }
void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
const bool configure) { const bool configure) {
if (launch_params.params.d == 16) { if (launch_params.params.d == 16) {
if( launch_params.params.s == 128 ) { if( launch_params.params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u>; using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.s == 256 ) { } else if( launch_params.params.seqlen_k == 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>; using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else { } else {
...@@ -91,10 +91,10 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l ...@@ -91,10 +91,10 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} }
} else if (launch_params.params.d == 32) { } else if (launch_params.params.d == 32) {
if( launch_params.params.s == 128 ) { if( launch_params.params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u>; using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.s == 256 ) { } else if( launch_params.params.seqlen_k == 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>; using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else { } else {
...@@ -102,10 +102,10 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l ...@@ -102,10 +102,10 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} }
} else if (launch_params.params.d == 64) { } else if (launch_params.params.d == 64) {
if( launch_params.params.s == 128 ) { if( launch_params.params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.s >= 256 ) { } else if( launch_params.params.seqlen_k >= 256 ) {
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
if (dprops->major == 8 && dprops->minor >= 0) { if (dprops->major == 8 && dprops->minor >= 0) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
...@@ -121,7 +121,7 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l ...@@ -121,7 +121,7 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
} }
} }
} else if (launch_params.params.d == 128) { } else if (launch_params.params.d == 128) {
if( launch_params.params.s == 128 ) { if( launch_params.params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>; using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else { } else {
...@@ -145,27 +145,27 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l ...@@ -145,27 +145,27 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); // run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } // }
// if (launch_params.params.d == 64) { // if (launch_params.params.d == 64) {
// if( launch_params.params.s == 128 ) { // if( launch_params.params.seqlen_k == 128 ) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); // run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else if( launch_params.params.s >= 256 ) { // } else if( launch_params.params.seqlen_k >= 256 ) {
// auto dprops = at::cuda::getCurrentDeviceProperties(); // auto dprops = at::cuda::getCurrentDeviceProperties();
// if (dprops->major == 8 && dprops->minor >= 0) { // if (dprops->major == 8 && dprops->minor >= 0) {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); // run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else if (dprops->major == 7 && dprops->minor == 5) { // } else if (dprops->major == 7 && dprops->minor == 5) {
// if (launch_params.is_dropout) { // Need to use the same block size as backward // // if (launch_params.is_dropout) { // Need to use the same block size as backward
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; // // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); // // run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else { // // } else {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; // // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); // // run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } // // }
// } // }
// } // }
// } // }
// if (launch_params.params.d == 128) { // if (launch_params.params.d == 128) {
// if( launch_params.params.s == 128 ) { // if( launch_params.params.seqlen_k == 128 ) {
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>; // using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure); // run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else { // } else {
......
...@@ -247,7 +247,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -247,7 +247,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
Gemm1 gemm_q_k(smem_, tidx); Gemm1 gemm_q_k(smem_, tidx);
// Allocate the global memory tile loader for Q. // Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx); Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
// Allocate the global memory tile loader for O. // Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
...@@ -273,9 +273,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -273,9 +273,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx); fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
// Allocate the global memory tile loader for K. // Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx); Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
// Allocate the global memory tile loader for V. // Allocate the global memory tile loader for V.
Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx); Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
// The base pointer of smem_v; // The base pointer of smem_v;
char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
...@@ -354,7 +354,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -354,7 +354,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Load over the entire sequence length. // Load over the entire sequence length.
for( int l = 0; l < steps; l++ ) { for( int l = 0; l < steps; l++ ) {
if((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen) break; if((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
// Declare the accumulators for the 1st gemm. // Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
...@@ -575,7 +575,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i ...@@ -575,7 +575,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
const bool is_final_write = const bool is_final_write =
Is_last Is_last
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen) || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
|| ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N)); || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
#pragma unroll #pragma unroll
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
...@@ -631,13 +631,14 @@ inline __device__ void device_1xN_loop(const Params &params) { ...@@ -631,13 +631,14 @@ inline __device__ void device_1xN_loop(const Params &params) {
auto seeds = at::cuda::philox::unpack(params.philox_args); auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds)); Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
const int STEPS = params.s / Kernel_traits::Cta_tile_p::M; constexpr int M = Kernel_traits::Cta_tile_p::M;
const int STEPS = (params.seqlen_q + M - 1) / M;
constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N; constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
if (params.s == N_per_loop) { if (params.seqlen_k == blocksize_c) {
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
} else { } else {
const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop; const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0); fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) { for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx); fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx);
......
...@@ -51,20 +51,22 @@ struct BlockInfoPadded { ...@@ -51,20 +51,22 @@ struct BlockInfoPadded {
: bidb(bidb), bidh(bidh), h(params.h) { : bidb(bidb), bidh(bidh), h(params.h) {
// The block index. // The block index.
sum_s = params.cu_seqlens[bidb]; sum_s_k = params.cu_seqlens_k[bidb];
actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s; actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k;
bidx = sum_s * params.h + bidh; sum_s_q = params.cu_seqlens_q[bidb];
actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - sum_s_q;
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx; tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
} }
__device__ bool stop_early(const int start_col = 0) const { __device__ bool stop_early(const int start_col = 0) const {
return actual_seqlen <= start_col; return actual_seqlen_k <= start_col;
} }
int actual_seqlen; int actual_seqlen_q;
int bidx; int actual_seqlen_k;
int sum_s; int sum_s_q;
int sum_s_k;
int bidh; int bidh;
int bidb; int bidb;
int tidx_global; int tidx_global;
......
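The BlockInfoPadded change above is the core of the cross-attention support: the single cu_seqlens / actual_seqlen pair becomes separate query and key variants. A minimal sketch (hypothetical lengths, not part of the commit) of how the cumulative-length tensors encode the offsets and lengths the kernel reads back:

import torch

seqlens_q = [3, 5, 2]                                    # hypothetical per-sequence query lengths
seqlens_k = [4, 4, 6]                                    # hypothetical per-sequence key lengths
cu_seqlens_q = torch.tensor([0, 3, 8, 10], dtype=torch.int32)   # exclusive prefix sums of seqlens_q
cu_seqlens_k = torch.tensor([0, 4, 8, 14], dtype=torch.int32)   # exclusive prefix sums of seqlens_k

bidb = 1                                                 # batch index, as in the kernel
sum_s_q = int(cu_seqlens_q[bidb])                        # token offset of this sequence's queries
actual_seqlen_q = int(cu_seqlens_q[bidb + 1]) - sum_s_q  # == seqlens_q[bidb] == 5
sum_s_k = int(cu_seqlens_k[bidb])                        # token offset of this sequence's keys/values
actual_seqlen_k = int(cu_seqlens_k[bidb + 1]) - sum_s_k  # == seqlens_k[bidb] == 4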
...@@ -5,7 +5,7 @@ import torch.nn as nn ...@@ -5,7 +5,7 @@ import torch.nn as nn
from einops import rearrange from einops import rearrange
from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D
from flash_attn.flash_attn_interface import flash_attn_func from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
...@@ -13,15 +13,15 @@ class FlashAttention(nn.Module): ...@@ -13,15 +13,15 @@ class FlashAttention(nn.Module):
"""Implement the scaled dot product attention with softmax. """Implement the scaled dot product attention with softmax.
Arguments Arguments
--------- ---------
softmax_temp: The temperature to use for the softmax attention. softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at (default: 1/sqrt(d_keys) where d_keys is computed at
runtime) runtime)
attention_dropout: The dropout rate to apply to the attention attention_dropout: The dropout rate to apply to the attention
(default: 0.1) (default: 0.1)
""" """
def __init__(self, softmax_temp=None, attention_dropout=0.0, device=None, dtype=None): def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
super().__init__() super().__init__()
self.softmax_temp = softmax_temp self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout self.dropout_p = attention_dropout
def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None, def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
...@@ -49,8 +49,10 @@ class FlashAttention(nn.Module): ...@@ -49,8 +49,10 @@ class FlashAttention(nn.Module):
max_s = seqlen max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device) device=qkv.device)
output = flash_attn_func(qkv, cu_seqlens, self.dropout_p if self.training else 0.0, output = flash_attn_unpadded_qkvpacked_func(
max_s, softmax_scale=self.softmax_temp, causal=causal) qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
else: else:
key_padding_mask_bool = key_padding_mask.bool_matrix key_padding_mask_bool = key_padding_mask.bool_matrix
...@@ -58,17 +60,19 @@ class FlashAttention(nn.Module): ...@@ -58,17 +60,19 @@ class FlashAttention(nn.Module):
x = rearrange(qkv, 'b s three h d -> b s (three h d)') x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool) x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
output_unpad = flash_attn_func(x_unpad, cu_seqlens, output_unpad = flash_attn_unpadded_qkvpacked_func(
self.dropout_p if self.training else 0.0, x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal) softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, batch_size, seqlen), indices, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads) 'b s (h d) -> b s h d', h=nheads)
else: else:
assert max_s is not None assert max_s is not None
output = flash_attn_func(qkv, cu_seqlens, output = flash_attn_unpadded_qkvpacked_func(
self.dropout_p if self.training else 0.0, qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
max_s, softmax_scale=self.softmax_temp, causal=causal) softmax_scale=self.softmax_scale, causal=causal
)
return output, None return output, None
......
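A minimal usage sketch of the FlashAttention module above (fp16 and a CUDA device assumed; the import path flash_attn.flash_attention is assumed from the repository layout):

import torch
from flash_attn.flash_attention import FlashAttention

batch, seqlen, nheads, headdim = 2, 256, 8, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device='cuda', dtype=torch.float16, requires_grad=True)
attn = FlashAttention(attention_dropout=0.1)
out, _ = attn(qkv, causal=True)        # key_padding_mask=None path shown in the diff above
print(out.shape)                       # torch.Size([2, 256, 8, 64])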
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import torch import torch
import torch.nn as nn import torch.nn as nn
import flash_attn_cuda import flash_attn_cuda
def _flash_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal, return_softmax): def _get_block_size(device, head_dim, is_dropout):
context, softmax_lse, *rest = flash_attn_cuda.fwd(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, assert head_dim in [16, 32, 64, 128]
False, causal, return_softmax, None) if head_dim in [16, 32]:
# if context.isnan().any() or softmax_lse.isnan().any(): return 256
elif head_dim == 64:
return 128 if (torch.cuda.get_device_capability(device) == (7, 5) and is_dropout) else 256
elif head_dim == 128:
return 256 if (torch.cuda.get_device_capability(device) == (8, 0) and not is_dropout) else 128
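_get_block_size mirrors the kernel dispatch on the C++ side: it reports the key-block size (blocksize_c) that the compiled kernels use for a given head dimension. An illustrative call, assuming a CUDA device is available and reusing the function defined just above:

# Values follow directly from the branches of _get_block_size:
#   head_dim 16/32        -> 256 everywhere
#   head_dim 64 on SM75   -> 128 with dropout, 256 otherwise
#   head_dim 128 on SM80  -> 256 without dropout, 128 otherwise
block = _get_block_size(torch.device('cuda'), 64, False)
print(block)   # expected 256 on SM80 (A100) and on SM75 (T4) without dropout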
def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, causal, return_softmax):
out, softmax_lse, *rest = flash_attn_cuda.fwd(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
False, causal, return_softmax, None
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint() # breakpoint()
S_dmask = rest[0] if return_softmax else None S_dmask = rest[0] if return_softmax else None
return context, softmax_lse, S_dmask return out, softmax_lse, S_dmask
def _flash_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, max_s, def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
softmax_scale, causal): max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal):
dqkv, dp, softmax_d = flash_attn_cuda.bwd(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, softmax_d = flash_attn_cuda.bwd(
softmax_scale, max_s, False, causal, None) dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
# if dqkv.isnan().any() or softmax_d.isnan().any(): max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None)
# if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint() # breakpoint()
return dqkv return dq, dk, dv, softmax_d
class FlashAttnFun(torch.autograd.Function): class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal): def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
# Save rng_state because the backward pass will regenerate the dropout mask # Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None: if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5) softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_attn_forward( out, softmax_lse, S_dmask = _flash_attn_forward(
qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=False qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
) )
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state) ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p ctx.dropout_p = dropout_p
ctx.max_s = max_s ctx.max_seqlen = max_seqlen
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
return context return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
def backward(ctx, dout): def backward(ctx, dout, *args):
qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
if rng_state is not None: if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state() cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state) torch.cuda.set_rng_state(rng_state)
# S_dmask is None, temporarily use another tensor just to get it running dqkv = torch.empty_like(qkv)
dqkv = _flash_attn_backward( _flash_attn_backward(
dout, qkv, context, context, softmax_lse, cu_seqlens, ctx.dropout_p, dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
ctx.max_s, ctx.softmax_scale, ctx.causal dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal
) )
if rng_state is not None: if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state) torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None return dqkv, None, None, None, None, None, None
# We duplicate code to return both the output and the softmax for testing class FlashAttnKVPackedFunc(torch.autograd.Function):
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class FlashAttnFunWithS(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal): def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
# Save rng_state because the backward pass is gonna regenerate the dropout mask softmax_scale, causal, return_softmax):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None: if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_attn_forward( out, softmax_lse, S_dmask = _flash_attn_forward(
qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=True q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
)
ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
def backward(ctx, dout, *args):
q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
_flash_attn_backward(
dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dq, dkv, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, causal, return_softmax):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
) )
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state) ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
ctx.dropout_p = dropout_p ctx.dropout_p = dropout_p
ctx.max_s = max_s ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
return context, S_dmask, softmax_lse return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored): def backward(ctx, dout, *args):
qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
if rng_state is not None: if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state() cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state) torch.cuda.set_rng_state(rng_state)
dqkv = _flash_attn_backward( dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, ctx.dropout_p, _flash_attn_backward(
ctx.max_s, ctx.softmax_scale, ctx.causal dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
) )
if rng_state is not None: if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state) torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None return dq, dk, dv, None, None, None, None, None, None, None, None
def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
causal=False, return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen: int. Maximum sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
causal, return_attn_probs)
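A usage sketch for flash_attn_unpadded_qkvpacked_func with a variable-length batch (hypothetical sizes; fp16 and a CUDA device assumed):

import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

nheads, headdim = 8, 64
seqlens = [5, 13, 7]                                   # hypothetical per-sequence lengths
qkv = torch.randn(sum(seqlens), 3, nheads, headdim,
                  device='cuda', dtype=torch.float16, requires_grad=True)
cu_seqlens = torch.tensor([0, 5, 18, 25], dtype=torch.int32, device='cuda')
out = flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max(seqlens), dropout_p=0.0,
                                         causal=True)
out.sum().backward()                                   # dqkv is accumulated into qkv.grad
print(out.shape)                                       # (sum(seqlens), nheads, headdim)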
def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale=None, causal=False,
return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total_q, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
return_attn_probs)
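The kvpacked entry point is the new cross-attention path added by this commit: queries and keys/values can come from different sequences with different lengths. A sketch with hypothetical sizes (fp16 and a CUDA device assumed):

import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func

nheads, headdim = 8, 64
seqlens_q, seqlens_k = [4, 6], [10, 3]                 # hypothetical query / key lengths
q = torch.randn(sum(seqlens_q), nheads, headdim,
                device='cuda', dtype=torch.float16, requires_grad=True)
kv = torch.randn(sum(seqlens_k), 2, nheads, headdim,
                 device='cuda', dtype=torch.float16, requires_grad=True)
cu_seqlens_q = torch.tensor([0, 4, 10], dtype=torch.int32, device='cuda')
cu_seqlens_k = torch.tensor([0, 10, 13], dtype=torch.int32, device='cuda')
out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k,
                                        max(seqlens_q), max(seqlens_k), dropout_p=0.0)
print(out.shape)                                       # (sum(seqlens_q), nheads, headdim)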
def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale=None, causal=False, return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into k and v.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total_q, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen_q, seqlen_k).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_attn_probs)
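The same cross-attention can go through the fully unpacked entry point. Reusing q, kv, and the cumulative lengths from the kvpacked sketch above, and splitting kv the same way FlashAttnKVPackedFunc does internally:

out = flash_attn_unpadded_func(q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k,
                               max(seqlens_q), max(seqlens_k), dropout_p=0.0)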
def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False, def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
return_attn_probs=False): return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation """For backward-compatibility only, will remove soon.
dropout_p should be set to 0.0 during evaluation
""" """
func = FlashAttnFun if not return_attn_probs else FlashAttnFunWithS return flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_s, dropout_p, softmax_scale,
return func.apply(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal) causal, return_attn_probs)
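For existing callers of flash_attn_func, the wrapper above keeps the legacy argument order working; the only difference from the new qkvpacked entry point is that dropout_p and the maximum sequence length swap positions (tensors reused from the qkvpacked sketch further above):

out_old = flash_attn_func(qkv, cu_seqlens, 0.0, max(seqlens))                     # legacy order
out_new = flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max(seqlens), 0.0)  # new order
# With dropout disabled the two calls are expected to produce the same output.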