Commit 46fd2a20 authored by Tri Dao

Support all head dims that are multiples of 8, up to 128

parent 97e13de2
@@ -35,7 +35,7 @@ PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
 FlashAttention currently supports:
 1. Turing or Ampere GPUs (e.g., A100, RTX 3090, T4, RTX 2080).
 2. fp16 and bf16 (bf16 requires Ampere GPUs).
-3. Head dimensions 16, 32, 64, 128 (head dim 128 backward requires A100).
+3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100.
 Our tentative roadmap:
 1. [Jun 2022] Make package pip-installable.
...
@@ -232,7 +232,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     const int head_size = sizes[D_DIM];
     const int total_k = k.size(TOTAL_DIM);
     TORCH_CHECK(batch_size > 0);
-    TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
+    TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
     CHECK_SHAPE(q, total_q, num_heads, head_size);
     CHECK_SHAPE(k, total_k, num_heads, head_size);
@@ -241,7 +241,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
-    int blocksize_c = head_size == 128 ? 128 : 256;
+    int blocksize_c = head_size > 64 ? 128 : 256;
     // Need to round max_seqlen_k to multiples of blocksize_c
     int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
     if( max_seqlen_k_ <= 128 ) {
@@ -386,8 +386,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     const int head_size = sizes[D_DIM];
     const int total_k = k.size(TOTAL_DIM);
     TORCH_CHECK(batch_size > 0);
-    TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
-    if (head_size == 128) { // TODO: eventually we should support SM86 and SM70 with d=128 as well
+    TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
+    if (head_size > 64) { // TODO: eventually we should support SM86 and SM70 with d=128 as well
         TORCH_CHECK(is_sm80);
     }
@@ -402,7 +402,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
-    int blocksize_c = (head_size == 128 || (is_sm75 && head_size == 64)) ? 128 : 256;
+    int blocksize_c = (head_size > 64 || (is_sm75 && head_size > 32)) ? 128 : 256;
     int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
     if( max_seqlen_k_ <= 128 ) {
         max_seqlen_k = 128;
...
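
As a quick illustration of the relaxed checks above, here is a small standalone Python sketch. It is not part of the commit and the helper names are made up; only the arithmetic is taken from the C++ lines above.

# Sketch of the head-size check and the blocksize_c / max_seqlen_k arithmetic
# in mha_fwd / mha_bwd. Helper names are illustrative only.

def check_head_size(head_size: int) -> None:
    # New rule: any multiple of 8 up to 128, instead of the old {16, 32, 64, 128} whitelist.
    assert head_size % 8 == 0 and head_size <= 128

def blocksize_c_fwd(head_size: int) -> int:
    # Forward: 128-wide key blocks once the head dim exceeds 64, else 256.
    return 128 if head_size > 64 else 256

def blocksize_c_bwd(head_size: int, is_sm75: bool) -> int:
    # Backward: Turing (sm75) already drops to 128-wide blocks for head dims above 32.
    return 128 if (head_size > 64 or (is_sm75 and head_size > 32)) else 256

def round_max_seqlen_k(max_seqlen_k_: int, blocksize_c: int) -> int:
    # Round the key sequence length up to a multiple of blocksize_c;
    # very short sequences are pinned to 128, as in the branch shown above.
    if max_seqlen_k_ <= 128:
        return 128
    return ((max_seqlen_k_ + blocksize_c - 1) // blocksize_c) * blocksize_c

check_head_size(40)                               # 40 = 5 * 8 is now accepted
print(blocksize_c_fwd(40))                        # 256
print(blocksize_c_bwd(40, is_sm75=True))          # 128 on a Turing GPU
print(round_max_seqlen_k(1000, blocksize_c=256))  # 1024
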
@@ -81,11 +81,13 @@ struct Gmem_tile_qkv {
     // Ctor.
     template< typename BInfo >
     inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
-            const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx, bool use_seqlen_q)
+            const uint32_t head_stride_in_elts, const int headdim,
+            const BInfo &binfo, const int tidx, bool use_seqlen_q)
         : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
         , actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k)
         , ptr(reinterpret_cast<char *>(ptr_))
-        , tidx_(tidx) {
+        , tidx_(tidx)
+        , col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_LDG / BYTES_PER_ELEMENT) < headdim) {
         // Compute the position in the sequence (within the CTA for the moment).
         int row = tidx / THREADS_PER_ROW;
@@ -121,7 +123,7 @@ struct Gmem_tile_qkv {
         for( int ii = 0; ii < LDGS; ++ii ) {
             // ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
             ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
-            preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
+            preds[ii] = col_predicate && ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
             fetch_[ii] = make_uint4(0, 0, 0, 0);
         }
@@ -140,7 +142,7 @@ struct Gmem_tile_qkv {
         for( int ii = 0; ii < LDGS; ++ii ) {
             // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
             char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
-            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
+            if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
                 fmha::stg(ptr_, data[ii]);
             }
         }
@@ -154,7 +156,7 @@ struct Gmem_tile_qkv {
         using elem2_type = typename std::conditional<std::is_same<elem_type, __half>::value, __half2, __nv_bfloat162>::type;
             // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
             elem2_type *ptr_ = reinterpret_cast<elem2_type *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
-            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
+            if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
                 #pragma unroll
                 for (int jj = 0; jj < 4; ++jj) {
                     atomicAdd(ptr_ + jj, reinterpret_cast<const elem2_type(&)[4]>(data[ii])[jj]);
@@ -172,7 +174,7 @@ struct Gmem_tile_qkv {
         for( int ii = 0; ii < LDGS; ++ii ) {
             // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
             float *ptr_ = reinterpret_cast<float *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
-            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
+            if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
                 #pragma unroll
                 for (int jj = 0; jj < 4; ++jj) {
                     const float2 data_f = fmha::half2_unpack<__half>(reinterpret_cast<const uint32_t(&)[4]>(data[ii])[jj]);
@@ -201,6 +203,7 @@ struct Gmem_tile_qkv {
     const int tidx_;
     // The length of the sequence loaded by that memory tile.
     int actual_seqlen;
+    const bool col_predicate;
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -246,11 +249,13 @@ struct Gmem_tile_o {
     template<typename BInfo>
     // inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx)
     inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
-            const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx)
+            const uint32_t head_stride_in_elts, const int headdim,
+            const BInfo &binfo, const int tidx)
         : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
         , actual_seqlen_q(binfo.actual_seqlen_q)
         , ptr_(reinterpret_cast<char *>(ptr))
-        , tidx_(tidx) {
+        , tidx_(tidx)
+        , col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_STG / BYTES_PER_ELEMENT) < headdim) {
         // Compute the position in the sequence (within the CTA for the moment).
         int row = tidx / THREADS_PER_ROW;
@@ -280,7 +285,7 @@ struct Gmem_tile_o {
         #pragma unroll
         for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
             int jj = mi * STGS_PER_LOOP + ii;
-            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q ) {
+            if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
                 break;
             }
@@ -308,7 +313,7 @@ struct Gmem_tile_o {
         #pragma unroll
         for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
             int jj = mi * STGS_PER_LOOP + ii;
-            if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q ) {
+            if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
                 break;
             }
@@ -335,6 +340,7 @@ struct Gmem_tile_o {
     // The length of the sequence loaded by that memory tile.
     int actual_seqlen_q;
     const int tidx_;
+    const bool col_predicate;
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
...
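
The col_predicate introduced in these two tiles is what lets a head dim such as 40 or 80 reuse the 64- or 128-wide tiles: a thread is switched off if the first column it would touch lies at or beyond headdim. Below is a small standalone sketch of that arithmetic; it is not from the commit, and the 16-byte access size and 2-byte element size are assumptions (the usual configuration for these fp16/bf16 tiles).

# Sketch of the col_predicate arithmetic in Gmem_tile_qkv / Gmem_tile_o.
# Assumed constants: 16-byte vectorized loads/stores of 2-byte (fp16/bf16) elements.
BYTES_PER_LDG = 16
BYTES_PER_ELEMENT = 2
ELTS_PER_LDG = BYTES_PER_LDG // BYTES_PER_ELEMENT  # each thread covers 8 columns per access

def col_predicate(tidx: int, threads_per_row: int, headdim: int) -> bool:
    # A thread participates only if its starting column is inside the real head dim.
    start_col = (tidx % threads_per_row) * ELTS_PER_LDG
    return start_col < headdim

# A 64-wide tile uses 8 threads per row (8 threads * 8 columns = 64 columns).
threads_per_row = 64 // ELTS_PER_LDG
active = [t for t in range(threads_per_row) if col_predicate(t, threads_per_row, headdim=40)]
print(active)  # [0, 1, 2, 3, 4]: columns 40..63 are never loaded or stored

Because one access covers 8 contiguous elements, whole 8-column groups are the masking granularity, which is why head dimensions only need to be multiples of 8.
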
@@ -138,19 +138,24 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the global memory tile loader for dQ.
-    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts, binfo, tidx);
-    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts,
+                         params.d, binfo, tidx);
+    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                                 params.d, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
@@ -160,7 +165,8 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for dO.
-    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                         params.d, binfo, tidx, true);
     // Allocate the shared memory tile loader for dO.
     Smem_tile_do smem_do(&smem_[0], tidx);
     Smem_tile_dot smem_dot(&smem_[0], tidx);
@@ -172,7 +178,8 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
@@ -702,7 +709,8 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     __syncthreads();
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
     smem_dv.load(dv_out);
-    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
+                         params.d, binfo, tidx, false);
     if (!Is_first) {
         gmem_dv.move(loop_step_idx);
     }
@@ -713,7 +721,8 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
     //     dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
     // }
-    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
+                         params.d, binfo, tidx, false);
     if (!Is_first) {
         gmem_dk.move(loop_step_idx);
     }
...
@@ -97,10 +97,13 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
     Gemm1 gemm_q_k(smem_, tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
-    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                       params.d, binfo, tidx);
+    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                               params.d, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
@@ -122,9 +125,11 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
...
@@ -105,32 +105,19 @@ void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, co
     // work around for MSVC issue
     FP16_SWITCH(params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
-        if (params.d == 16) {
-            if( params.seqlen_k == 128 ) {
-                using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            } else if( params.seqlen_k == 256 ) {
-                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            } else {
-                // TD [2022-05-15] 512 gives wrong results rn
-                // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u, elem_type>;
-                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            }
-        } else if (params.d == 32) {
-            if( params.seqlen_k == 128 ) {
+        if (params.d <= 32) {
+            if (params.seqlen_k == 128) {
                 using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            } else if( params.seqlen_k >= 256 ) {
+            } else if (params.seqlen_k >= 256) {
                 using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
             }
-        } else if (params.d == 64) {
-            if( params.seqlen_k == 128 ) {
+        } else if (params.d <= 64) {
+            if (params.seqlen_k == 128) {
                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
-            } else if( params.seqlen_k >= 256 ) {
+            } else if (params.seqlen_k >= 256) {
                 if (dprops->major == 8 && dprops->minor == 0) {
                     // Don't share smem for K & V, and don't keep V in registers
                     // This speeds things up by 2-3% by avoiding register spills, but it
@@ -146,7 +133,7 @@ void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, co
                     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
                 }
             }
-        } else if (params.d == 128) {
+        } else if (params.d <= 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
         }
...
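
The net effect of this dispatch is that an arbitrary head dim is bucketed into the 32-, 64- or 128-wide kernel. Here is a rough Python sketch of that mapping; the function and tuple layout are ours, and it ignores the SM-specific kernel-traits variants hidden in the truncated branch above.

# Rough model of the backward-kernel dispatch above. Returns
# (sequence block size, padded head dim) as they appear in FMHA_kernel_traits<S, D, ...>.
# seqlen_k is assumed to already be the rounded value (128 or a multiple of 256).
def dgrad_kernel_traits(d: int, seqlen_k: int) -> tuple:
    assert d % 8 == 0 and d <= 128
    if d <= 32:
        return (128 if seqlen_k == 128 else 256, 32)
    elif d <= 64:
        return (128 if seqlen_k == 128 else 256, 64)
    else:  # 64 < d <= 128
        return (128, 128)

print(dgrad_kernel_traits(40, 512))   # (256, 64): head dim 40 runs in the 64-wide kernel
print(dgrad_kernel_traits(80, 2048))  # (128, 128)

Inside that padded kernel, the columns beyond the real head dim are the ones masked off by col_predicate above.
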
@@ -144,19 +144,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the global memory tile loader for dQ.
-    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts, binfo, tidx);
-    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts,
+                         params.d, binfo, tidx);
+    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                                 params.d, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
@@ -166,7 +171,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for dO.
-    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                         params.d, binfo, tidx, true);
     // Allocate the shared memory tile loader for dO.
     Smem_tile_do smem_do(&smem_[0], tidx);
     Smem_tile_dot smem_dot(&smem_[0], tidx);
@@ -178,7 +184,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
     Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
@@ -657,7 +664,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     __syncthreads();
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
     smem_dv.load(dv_out);
-    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
+                         params.d, binfo, tidx, false);
     // using Gmem_tile_dkv_accum = typename Kernel_traits::Gmem_tile_dkv_accum;
     // Gmem_tile_dkv_accum gmem_dv_accum(params.dv_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
     // static_assert(Gmem_tile_dkv_accum::LDGS == Smem_tile_dv::NUM_LDS);
@@ -674,7 +682,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     uint4 dk_out[Smem_tile_dk::NUM_LDS];
     smem_dk.load(dk_out);
-    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
+                         params.d, binfo, tidx, false);
     // Gmem_tile_dkv_accum gmem_dk_accum(params.dk_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
     if (!Is_first) {
         gmem_dk.move(loop_step_idx);
...
@@ -114,36 +114,23 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params) {
 void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params) {
     FP16_SWITCH(launch_params.params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
-        if (launch_params.params.d == 16) {
-            if( launch_params.params.seqlen_k == 128 ) {
-                using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            } else if( launch_params.params.seqlen_k == 256 ) {
-                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            } else {
-                // TD [2022-05-15] 512 gives wrong results rn
-                // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>;
-                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            }
-        } else if (launch_params.params.d == 32) {
-            if( launch_params.params.seqlen_k == 128 ) {
+        if (launch_params.params.d <= 32) {
+            if (launch_params.params.seqlen_k == 128) {
                 using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            } else if( launch_params.params.seqlen_k >= 256 ) {
+            } else if (launch_params.params.seqlen_k >= 256) {
                 using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
-        } else if (launch_params.params.d == 64) {
-            if( launch_params.params.seqlen_k == 128 ) {
+        } else if (launch_params.params.d <= 64) {
+            if (launch_params.params.seqlen_k == 128) {
                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-            } else if( launch_params.params.seqlen_k >= 256 ) {
+            } else if (launch_params.params.seqlen_k >= 256) {
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
-        } else if (launch_params.params.d == 128) {
+        } else if (launch_params.params.d <= 128) {
             // TD [2022-10-21]: Previously for SM80 we use block size 256 and keep K in shared memory
             // to reduce register spilling. However, that increases the smem usage from ~41KB to ~105KB,
             // reducing occupancy (only 1 kernel can be scheduled per SM instead of 2). This strategy gives
...
@@ -259,10 +259,13 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     Gemm1 gemm_q_k(smem_, tidx);
     // Allocate the global memory tile loader for Q.
-    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
+    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
     // Allocate the global memory tile loader for O.
-    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
-    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts, params.o_tmp_head_stride_in_elts, binfo, tidx);
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                       params.d, binfo, tidx);
+    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts,
+                               params.o_tmp_head_stride_in_elts, params.d, binfo, tidx);
     // Allocate the global memory tile loader for S.
     Gmem_tile_s gmem_s(params, binfo, tidx);
     Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
@@ -293,9 +296,11 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
     // Allocate the global memory tile loader for K.
-    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // Allocate the global memory tile loader for V.
-    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx, false);
+    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
+                       params.d, binfo, tidx, false);
     // The base pointer of smem_v;
     char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
...
@@ -6,8 +6,8 @@ import flash_attn_cuda
 def _get_block_size(device, head_dim, is_dropout):
-    assert head_dim in [16, 32, 64, 128]
-    return 256 if head_dim in [16, 32, 64] else 128
+    assert head_dim % 8 == 0 and head_dim <= 128
+    return 256 if head_dim <= 64 else 128

 def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
...
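
A few concrete values from the relaxed helper, re-declared here purely for illustration (device and is_dropout are unused in this version of the function):

def _get_block_size(device, head_dim, is_dropout):
    # Copied from the diff above for illustration; device / is_dropout are unused here.
    assert head_dim % 8 == 0 and head_dim <= 128
    return 256 if head_dim <= 64 else 128

print(_get_block_size("cuda", 40, False))   # 256
print(_get_block_size("cuda", 80, True))    # 128
print(_get_block_size("cuda", 128, False))  # 128
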
@@ -340,7 +340,7 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [False])
-@pytest.mark.parametrize('d', [128, 64, 32, 16])
+@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 # @pytest.mark.parametrize('seqlen', [128])
@@ -362,8 +362,8 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
     Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
-    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
-    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
+    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
+    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
     qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
         x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
@@ -395,7 +395,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-    if is_sm80 or d < 128: # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64: # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
         dqkv = dqkv_pad_fn(dqkv_unpad)
@@ -421,7 +421,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.98 <= dropout_fraction / dropout_p <= 1.02
-    if is_sm80 or d < 128: # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64: # Only run backward for d=128 on A100
         # Error for dK and dV could be a bit higher if we're splitting along seqlen_q dimension
         assert (dqkv - dqkv_ref).abs().max().item() <= 4 * (dqkv_pt - dqkv_ref).abs().max().item()
         # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
@@ -430,7 +430,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
-@pytest.mark.parametrize('d', [128, 64, 32, 16])
+@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 # @pytest.mark.parametrize('seqlen', [128])
@@ -487,7 +487,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-    if is_sm80 or d < 128: # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64: # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dq_unpad, dkv_unpad, = torch.autograd.grad(output, (q_unpad, kv_unpad), g)
         dq = dq_pad_fn(dq_unpad)
@@ -512,7 +512,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-    if is_sm80 or d < 128: # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64: # Only run backward for d=128 on A100
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dkv - dkv_ref).abs().max().item() <= 2 * (dkv_pt - dkv_ref).abs().max().item()
         # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol)
@@ -522,7 +522,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
-@pytest.mark.parametrize('d', [128, 64, 32, 16])
+@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 # @pytest.mark.parametrize('seqlen', [128])
@@ -579,7 +579,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-    if is_sm80 or d < 128: # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64: # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output, (q_unpad, k_unpad, v_unpad), g)
         dq = dq_pad_fn(dq_unpad)
@@ -605,7 +605,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-    if is_sm80 or d < 128: # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64: # Only run backward for d=128 on A100
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -618,7 +618,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [False])
-@pytest.mark.parametrize('d', [128, 64, 32, 16])
+@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [512])
 @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
@@ -681,7 +681,7 @@ def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-    if is_sm80 or d < 128: # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64: # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
         dqkv = dqkv_pad_fn(dqkv_unpad)
@@ -707,7 +707,7 @@ def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-    if is_sm80 or d < 128: # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64: # Only run backward for d=128 on A100
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
         # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
@@ -715,7 +715,7 @@ def test_flash_attn_split(seqlen, d, dropout_p, causal, dtype):
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
-@pytest.mark.parametrize('d', [128, 64, 32, 16])
+@pytest.mark.parametrize('d', [128, 64, 80, 40, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 # @pytest.mark.parametrize('seqlen', [128])
@@ -749,7 +749,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
         S_dmask_0, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
     )
-    if is_sm80 or d < 128: # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64: # Only run backward for d=128 on A100
         g = torch.randn_like(output_unpad_0)
         dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
                                                                   (q_unpad, k_unpad, v_unpad), g)
@@ -768,7 +768,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     # assert torch.equal(sm_lse, sm_lse_0)
     assert torch.equal(S_dmask_converted, S_dmask_converted_0)
-    if is_sm80 or d < 128: # Only run backward for d=128 on A100
+    if is_sm80 or d <= 64: # Only run backward for d=128 on A100
         dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
                                                             (q_unpad, k_unpad, v_unpad), g)
         assert torch.equal(dq_unpad, dq_unpad_0)
...