Commit 083e8f52 authored by Tri Dao

Implement local attention


Co-authored-by: Timothee Lacroix <t@mistral.ai>
parent 4c8ff915
@@ -40,7 +40,8 @@ void set_params_fprop(Flash_fwd_params &params,
                       void *softmax_lse_d,
                       float p_dropout,
                       float softmax_scale,
-                      bool is_causal) {
+                      int window_size_left,
+                      int window_size_right) {
     // Reset the parameters
     memset(&params, 0, sizeof(params));
@@ -105,7 +106,15 @@ void set_params_fprop(Flash_fwd_params &params,
     params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
     TORCH_CHECK(p_dropout < 1.f);
-    params.is_causal = is_causal;
+    // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+    params.is_causal = window_size_left < 0 && window_size_right == 0;
+    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
+    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
+    params.window_size_left = window_size_left;
+    params.window_size_right = window_size_right;
     params.is_seqlens_k_cumulative = true;
 }
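For reference, the window convention used throughout this change: with the last query aligned to the last key, query i may attend to key j only if j lies in [i + seqlen_k - seqlen_q - window_size_left, i + seqlen_k - seqlen_q + window_size_right], where a negative size means unbounded on that side; causal attention is the special case (window_size_left < 0, window_size_right == 0). A minimal host-side C++ sketch of that predicate (illustrative only, not code from this commit; the helper name is made up):

#include <cstdio>

// Sliding-window ("local") attention predicate, assuming the convention described above.
// A negative window size means "unbounded" on that side.
static bool in_local_window(int i, int j, int seqlen_q, int seqlen_k,
                            int window_size_left, int window_size_right) {
    const int shift = seqlen_k - seqlen_q;   // aligns the last query with the last key
    const bool left_ok  = window_size_left  < 0 || j >= i + shift - window_size_left;
    const bool right_ok = window_size_right < 0 || j <= i + shift + window_size_right;
    return left_ok && right_ok;
}

int main() {
    // seqlen_q = seqlen_k = 8 with window (2, 0): each query sees itself and the two previous keys.
    for (int i = 0; i < 8; ++i) {
        for (int j = 0; j < 8; ++j) { printf("%c", in_local_window(i, j, 8, 8, 2, 0) ? 'x' : '.'); }
        printf("\n");
    }
    return 0;
}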
@@ -138,7 +147,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       void *dsoftmax_sum_d,
                       float p_dropout,
                       float softmax_scale,
-                      bool is_causal) {
+                      int window_size_left,
+                      int window_size_right) {
     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -149,7 +159,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                      softmax_lse_d,
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);
     // Set the pointers and strides.
     params.do_ptr = dout.data_ptr();
@@ -242,6 +253,8 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
         const float p_dropout,
         const float softmax_scale,
         bool is_causal,
+        const int window_size_left,
+        int window_size_right,
         const bool return_softmax,
         c10::optional<at::Generator> gen_) {
@@ -281,10 +294,11 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (is_causal) { window_size_right = 0; }
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size_og % 8 == 0;
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0;
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -353,7 +367,8 @@ mha_fwd(at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
                      softmax_lse.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);
     // This needs to match with run_mha_fwd_splitkv_dispatch
     const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
@@ -421,9 +436,12 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
+              const int window_size_left,
+              int window_size_right,
               const bool return_softmax,
               c10::optional<at::Generator> gen_) {
+    if (is_causal) { window_size_right = 0; }
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -534,7 +552,8 @@ mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q
                      softmax_lse.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
@@ -600,8 +619,12 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
         const float p_dropout,          // probability to drop
         const float softmax_scale,
         const bool is_causal,
+        const int window_size_left,
+        int window_size_right,
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state) {
+    if (is_causal) { window_size_right = 0; }
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -748,7 +771,8 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
                      softmax_d.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);
     auto launch = &run_mha_bwd;
     // launch(params, stream, /*configure=*/true);
@@ -804,9 +828,12 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
+              const int window_size_left,
+              int window_size_right,
               c10::optional<at::Generator> gen_,
-              c10::optional<at::Tensor> &rng_state
-) {
+              c10::optional<at::Tensor> &rng_state) {
+    if (is_causal) { window_size_right = 0; }
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -969,7 +996,8 @@ mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                      softmax_d.data_ptr(),
                      p_dropout,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);
     auto launch = &run_mha_bwd;
     // launch(params, stream, /*configure=*/true);
@@ -1019,6 +1047,8 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
                 c10::optional<at::Tensor> &out_,  // batch_size x seqlen_q x num_heads x head_size
                 const float softmax_scale,
                 bool is_causal,
+                const int window_size_left,
+                int window_size_right,
                 bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                 int num_splits
                 ) {
@@ -1059,10 +1089,11 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    if (is_causal) { window_size_right = 0; }
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size_og % 8 == 0;
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0;
     if (seqlenq_ngroups_swapped) {
         const int ngroups = num_heads / num_heads_k;
         q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -1125,7 +1156,8 @@ mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_he
                      softmax_lse.data_ptr(),
                      /*p_dropout=*/0.f,
                      softmax_scale,
-                     is_causal);
+                     window_size_left,
+                     window_size_right);
     at::Tensor k, v, k_padded, v_padded;
     if (k_.has_value()) {
......
@@ -105,6 +105,9 @@ struct Flash_fwd_params : public Qkv_params {
     float rp_dropout;
     float scale_softmax_rp_dropout;
+    // Local window size
+    int window_size_left, window_size_right;
     // Random state.
     at::PhiloxCudaState philox_args;
......
@@ -422,7 +422,7 @@ inline __device__ void convert_dKV(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
     using Element = typename Kernel_traits::Element;
@@ -447,6 +447,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;
     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
+    if (Is_local) {
+        m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
+    }
     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
@@ -655,14 +658,53 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
     int m_block = m_block_max - 1;
-    int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
-    // We're guaranteed that m_block_min <= m_block:
+    int m_block_min = (!Is_causal && !Is_local)
+        ? 0
+        : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
+    // If not local, we're guaranteed that m_block_min <= m_block:
     // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
     // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
     // So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
     // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
     // So m_block_m - 1 = (actual_seqlen_q - 1) / kBlockM.
     // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
+    // However, if local, then this possible to have some blocks of K & V not attending to any query.
+    // We might need to exit early and write 0 to dK and dV for those blocks.
+    // Otherwise we get wrong result for the case where we don't enter the for loop.
+    // And we might read OOB elements from gQ and gdO.
+    if (Is_local && m_block < m_block_min) {
+        const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+          + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
+        const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+          + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
+        Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
+                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                 make_stride(params.dk_row_stride, _1{}));
+        Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
+                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                                 make_stride(params.dv_row_stride, _1{}));
+        typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
+        auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
+        Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
+        Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
+        Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
+        Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
+        clear(tdKrdK);
+        clear(tdVrdV);
+        Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)
+        Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
+        Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
+        #pragma unroll
+        for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
+        // Clear_OOB_K must be false since we don't want to write zeros to gmem
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+            gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        );
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+            gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
+        );
+        return;
+    }
     if (Double_buffer && m_block % 2 == 1) {  // Double buffer for sQ
         tQsQ.data() = tQsQ.data() + size(sQ);
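The m_block range computed in the hunk above can be empty when the window is narrow, which is exactly the case the new Is_local early-exit path handles by writing zeros to dK and dV. A host-side sketch of the same block-range arithmetic (illustrative names, not kernel code):

#include <algorithm>
#include <cstdio>

// Query-block range visited by one K/V block in the backward pass, mirroring the
// m_block_min / m_block_max formulas above for the local (windowed) case.
struct MBlockRange { int min_block, max_block; };   // iterate m_block in [min_block, max_block)

static MBlockRange backward_m_block_range(int n_block, int kBlockM, int kBlockN,
                                          int seqlen_q, int seqlen_k,
                                          int window_size_left, int window_size_right) {
    auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
    int m_block_max = std::min(ceil_div(seqlen_q, kBlockM),
                               ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
    int m_block_min = std::max(0, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
    return {m_block_min, m_block_max};
}

int main() {
    // seqlen_q = seqlen_k = 1024, 128x128 tiles, window (128, 0):
    // key block 0 is only visible to the first two query blocks.
    MBlockRange r = backward_m_block_range(/*n_block=*/0, 128, 128, 1024, 1024, 128, 0);
    printf("m_block in [%d, %d)\n", r.min_block, r.max_block);   // [0, 2)

    // seqlen_q = 256, seqlen_k = 1024, window (0, 0): queries only see keys 768..1023,
    // so key block 0 attends to no query at all and the range comes out empty; this is
    // the situation the early-exit path above zeroes dK/dV for.
    r = backward_m_block_range(/*n_block=*/0, 128, 128, 256, 1024, 0, 0);
    printf("m_block in [%d, %d)\n", r.min_block, r.max_block);   // empty range
    return 0;
}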
@@ -777,12 +819,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // However, it's possible that the values in acc_s are so large that they overflow
         // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
         // So we need to mask out the elements beyond actual_seqlen_k.
-        if (!Is_causal) {
+        if (!Is_causal && !Is_local) {
             if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
                 flash::apply_mask(scores, binfo.actual_seqlen_k,
                                   n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
             }
-        } else {
+        } else if (Is_causal) {
             // Putting this causal masking right after acc_s is *much* slower for some reason.
             // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
             // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
@@ -795,6 +837,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                                           // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
                                           AtomLayoutMS * 16);
             }
+        } else if (Is_local) {
+            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
+                || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
+                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
+                flash::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                                        binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+                                        binfo.actual_seqlen_q, AtomLayoutMS * 16,
+                                        params.window_size_left, params.window_size_right);
+            }
         }
         // if (cute::thread(32, 0)) { print(scores); }
         // Compute the exponential value.
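The three-way condition guarding apply_mask_local above is a conservative tile-level test: element-wise masking is only applied to kBlockM x kBlockN tiles that may cross the right window edge, the left window edge, or the end of seqlen_k (tiles entirely outside the window never reach this point because of the block-range limits). A host-side restatement of that test, with illustrative names:

#include <cstdio>

// Does this scores tile possibly contain masked elements? Mirrors the condition above.
static bool tile_needs_local_mask(int m_block, int n_block, int kBlockM, int kBlockN,
                                  int seqlen_q, int seqlen_k,
                                  int window_size_left, int window_size_right,
                                  bool is_even_mn) {
    const int shift = seqlen_q - seqlen_k;
    const bool crosses_right = m_block * kBlockM < (n_block + 1) * kBlockN + shift - window_size_right;
    const bool crosses_left  = (m_block + 1) * kBlockM >= n_block * kBlockN + shift + window_size_left;
    const bool crosses_end   = !is_even_mn && (n_block + 1) * kBlockN >= seqlen_k;
    return crosses_right || crosses_left || crosses_end;
}

int main() {
    // seqlen_q = seqlen_k = 1024, 128x128 tiles, window (384, 0): only the tiles one block
    // below the diagonal are fully inside the window and can skip element-wise masking.
    for (int m = 0; m < 8; ++m) {
        for (int n = 0; n < 8; ++n) {
            printf("%c", tile_needs_local_mask(m, n, 128, 128, 1024, 1024, 384, 0, true) ? 'M' : '.');
        }
        printf("\n");
    }
    return 0;
}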
@@ -1510,7 +1562,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
     const int n_block = blockIdx.x;
@@ -1519,7 +1571,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
     // The block index for the head.
     const int bidh = blockIdx.z;
-    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
......
@@ -23,9 +23,10 @@ __global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
     flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
 }
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K>
 __global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
-    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K>(params);
+    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
+    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K>(params);
 }
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
@@ -62,16 +63,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
         BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
-                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
-                if (smem_size_dq_dk_dv >= 48 * 1024) {
-                    C10_CUDA_CHECK(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-                }
-                kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                    // If Is_local, set Is_causal to false
+                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
+                    if (smem_size_dq_dk_dv >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                    }
+                    kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                });
             });
         });
     });
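The launch code above folds the runtime window settings into an extra Is_local template flag and clears Is_causal whenever Is_local is set, which is the invariant the kernels check with static_assert. A reduced sketch of that dispatch pattern with made-up names, writing the BOOL_SWITCH expansion out as a plain if/else:

#include <cstdio>

template <bool Is_causal, bool Is_local>
void launch_kernel_sketch() {
    static_assert(!(Is_causal && Is_local), "If Is_local is true, Is_causal should be false");
    printf("Is_causal=%d Is_local=%d\n", int(Is_causal), int(Is_local));
}

// Mirrors "IsCausalConst && !Is_local, Is_local" in the kernel instantiation above.
void dispatch_sketch(bool is_causal, int window_size_left, int window_size_right) {
    const bool is_local = window_size_left >= 0 || window_size_right >= 0;
    if (is_local) {
        launch_kernel_sketch</*Is_causal=*/false, /*Is_local=*/true>();
    } else if (is_causal) {
        launch_kernel_sketch</*Is_causal=*/true, /*Is_local=*/false>();
    } else {
        launch_kernel_sketch</*Is_causal=*/false, /*Is_local=*/false>();
    }
}

int main() {
    dispatch_sketch(/*is_causal=*/true,  /*window_size_left=*/-1,  /*window_size_right=*/-1);
    dispatch_sketch(/*is_causal=*/false, /*window_size_left=*/128, /*window_size_right=*/0);
    return 0;
}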
......
@@ -71,7 +71,7 @@ inline __device__ void write_softmax_to_gmem(
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
     using Element = typename Kernel_traits::Element;
@@ -93,16 +93,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
+    const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
     int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
-    if (Is_causal) {
+    if (Is_causal || Is_local) {
         n_block_max = std::min(n_block_max,
-                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
         // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
         //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
         // }
         // We exit early and write 0 to gO and gLSE.
         // Otherwise we might read OOB elements from gK and gV.
-        if (n_block_max <= 0) {
+        if (n_block_max <= n_block_min) {
             // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
             // exit early and no one saves the rng state.
             if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
@@ -145,6 +146,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             return;
         }
     }
+    // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }
     // We iterate over the blocks in reverse order. This is because the last block is the only one
     // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
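The n_block_min / n_block_max computation above restricts each query block to the key blocks its window can reach, and the early-exit branch handles the case where that range is empty. A host-side sketch of the same arithmetic for the local case (illustrative names only):

#include <algorithm>
#include <cstdio>

// K/V block range visited by one query block in the forward kernel, mirroring the
// n_block_min / n_block_max formulas above when Is_local is set.
struct NBlockRange { int min_block, max_block; };   // iterate n_block in [min_block, max_block)

static NBlockRange forward_n_block_range(int m_block, int kBlockM, int kBlockN,
                                         int seqlen_q, int seqlen_k,
                                         int window_size_left, int window_size_right) {
    auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
    const int n_block_min = std::max(0, (m_block * kBlockM + seqlen_k - seqlen_q - window_size_left) / kBlockN);
    const int n_block_max = std::min(ceil_div(seqlen_k, kBlockN),
                                     ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_size_right, kBlockN));
    return {n_block_min, n_block_max};
}

int main() {
    // seqlen_q = seqlen_k = 1024, 128x128 tiles, window (256, 0):
    // query block 4 (rows 512..639) only visits key blocks 2, 3 and 4.
    NBlockRange r = forward_n_block_range(/*m_block=*/4, 128, 128, 1024, 1024, 256, 0);
    printf("n_block in [%d, %d)\n", r.min_block, r.max_block);   // [2, 5)
    return 0;
}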
@@ -326,9 +328,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
-    constexpr int n_masking_steps = !Is_causal
+    constexpr int n_masking_steps = (!Is_causal && !Is_local)
         ? 1
-        : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
+        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
     #pragma unroll
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
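A quick worked example of the n_masking_steps formula above (plain C++, illustrative helper names): with kBlockM = 128 and kBlockN = 64, causal attention with even sequence lengths masks the last ceil(128/64) = 2 key blocks, while uneven lengths or a local window take the +1 path and mask 3.

#include <cstdio>

static constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

static constexpr int n_masking_steps(bool is_causal, bool is_local, bool is_even_mn,
                                     int kBlockM, int kBlockN) {
    return (!is_causal && !is_local)
        ? 1
        : ((is_even_mn && is_causal) ? ceil_div(kBlockM, kBlockN) : ceil_div(kBlockM, kBlockN) + 1);
}

int main() {
    printf("%d %d %d\n",
           n_masking_steps(/*causal*/true,  /*local*/false, /*even*/true,  128, 64),   // 2
           n_masking_steps(/*causal*/true,  /*local*/false, /*even*/false, 128, 64),   // 3
           n_masking_steps(/*causal*/false, /*local*/true,  /*even*/true,  128, 64));  // 3
    return 0;
}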
@@ -356,11 +358,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        // if (cute::thread0()) { print(scores); }
+        // if (cute::thread0()) { print_tensor(scores); }
         // We don't put the masking before the matmul S = Q K^T because we don't clear sK
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
         // can produce Inf / NaN.
-        if (!Is_causal) {
+        if (!Is_causal && !Is_local) {
             if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
         } else {
             // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
@@ -374,18 +376,21 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             // Idk why it's get<1> and not get<0> of the stride.
             // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
             // I can't get the stride from idx_row
-            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                                     // m_block * kBlockM + get<0>(idx_row(0)),
-                                     m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                                     binfo.actual_seqlen_q,
-                                     kNWarps * 16);
-                                     // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
-                                     // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
+            flash::apply_mask_local</*HasWSLeft=*/Is_local>(
+                scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                // m_block * kBlockM + get<0>(idx_row(0)),
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q, kNWarps * 16,
+                params.window_size_left, params.window_size_right
+                // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
+                // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
+            );
+            // if (cute::thread0()) { print_tensor(scores); }
         }
         flash::cp_async_wait<0>();
         __syncthreads();
-        if (n_block > 0) {
+        if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
@@ -396,8 +401,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // TODO: when we have key_padding_mask we'll need to Check_inf
         masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
         // Convert scores from fp32 to fp16/bf16
         Tensor rP = flash::convert_type<Element>(scores);
@@ -426,14 +431,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // if (cute::thread0()) { print(scores); }
         // This check is at the end of the loop since we always have at least 1 iteration
-        if (n_masking_steps > 1 && n_block <= 0) {
+        if (n_masking_steps > 1 && n_block <= n_block_min) {
             --n_block;
             break;
         }
     }
     // These are the iterations where we don't need masking on S
-    for (; n_block >= 0; --n_block) {
+    for (; n_block >= n_block_min; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
         flash::cp_async_wait<0>();
@@ -450,7 +455,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         flash::cp_async_wait<0>();
         __syncthreads();
-        if (n_block > 0) {
+        if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
             flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
@@ -461,7 +466,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+            flash::apply_mask_local(
+                scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q, kNWarps * 16,
+                params.window_size_left, params.window_size_right
+            );
+        }
+        softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -568,7 +581,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
     using Element = typename Kernel_traits::Element;
@@ -599,11 +612,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
     const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
-    const int n_block_min = n_split_idx * n_blocks_per_split;
+    const int n_block_min = !Is_local
+        ? n_split_idx * n_blocks_per_split
+        : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
     int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
-    if (Is_causal) {
+    if (Is_causal || Is_local) {
         n_block_max = std::min(n_block_max,
-                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
     }
     if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
         // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
@@ -842,21 +857,21 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
-        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
         // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
         // We do this by setting the row stride of gCos / gSin to 0.
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
-                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
         Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
-                                  make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
         Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                       Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                      make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                      make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
         Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                       Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                      make_stride(Is_causal ? params.rotary_dim / 2 : 0, _1{}));
+                                      make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
         Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
         Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
         Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
@@ -895,9 +910,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
     // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
-    constexpr int n_masking_steps = !Is_causal
+    constexpr int n_masking_steps = (!Is_causal && !Is_local)
         ? 1
-        : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
+        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
     #pragma unroll
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
@@ -929,13 +944,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // We don't put the masking before the matmul S = Q K^T because we don't clear sK
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
         // can produce Inf / NaN.
-        if (!Is_causal) {
+        if (!Is_causal && !Is_local) {
             if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
         } else {
-            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                                     m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                                     binfo.actual_seqlen_q,
-                                     kNWarps * 16);
+            flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                                    m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                                    binfo.actual_seqlen_q, kNWarps * 16,
+                                    params.window_size_left, params.window_size_right
+            );
         }
         flash::cp_async_wait<0>();
@@ -954,8 +970,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // We have key_padding_mask so we'll need to Check_inf
         masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
         // Convert scores from fp32 to fp16/bf16
@@ -1003,7 +1019,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+            flash::apply_mask_local(
+                scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q, kNWarps * 16,
+                params.window_size_left, params.window_size_right
+            );
+        }
+        softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -1106,7 +1130,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1122,12 +1146,12 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
 inline __device__ void compute_attn_splitkv(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -1136,7 +1160,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
     const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
     const int n_split_idx = Split ? blockIdx.y : 0;
     const int num_n_splits = Split ? gridDim.y : 1;
-    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+    flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
......
@@ -10,14 +10,15 @@
 #include "flash.h"
 #include "flash_fwd_kernel.h"
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
 __global__ void flash_fwd_kernel(Flash_fwd_params params) {
-    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
+    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
+    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params);
 }
-template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
 __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
-    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
+    flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params);
 }
 template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
...@@ -42,23 +43,25 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -42,23 +43,25 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const bool return_softmax = params.p_ptr != nullptr; const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
// Will only return softmax if dropout, to reduce compilation time. BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // Will only return softmax if dropout, to reduce compilation time.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If return_softmax, set IsEvenMNConst to false to reduce number of templates
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst && IsEvenKConst && (!ReturnSoftmaxConst) && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>; // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>; // If Is_local, set Is_causal to false
if (smem_size >= 48 * 1024) { auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
C10_CUDA_CHECK(cudaFuncSetAttribute( if (smem_size >= 48 * 1024) {
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); C10_CUDA_CHECK(cudaFuncSetAttribute(
} kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// int ctas_per_sm; }
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( // int ctas_per_sm;
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
C10_CUDA_KERNEL_LAUNCH_CHECK(); kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}); });
}); });
}); });
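For readers less familiar with the BOOL_SWITCH dispatch above, here is a minimal Python sketch (illustrative only, not library code; names are made up) of how the runtime window sizes are turned into the compile-time flags that select a flash_fwd_kernel instantiation: any non-negative window disables the causal specialization and the even-MN fast path.

```python
# Illustrative only: mirrors the flag derivation in run_flash_fwd above.
def choose_fwd_kernel_flags(window_size_left, window_size_right, is_causal,
                            is_even_mn, is_even_k, return_softmax, head_dim):
    is_local = window_size_left >= 0 or window_size_right >= 0
    # If Is_local is set, the kernel is instantiated with Is_causal=false
    # (causal is just local attention with an infinite left window).
    is_causal = is_causal and not is_local
    # Local masking touches partially filled blocks, so the even-MN fast path is
    # disabled, as it also is for return_softmax or head_dim > 128.
    is_even_mn = (is_even_mn and is_even_k and not is_local
                  and not return_softmax and head_dim <= 128)
    return {"Is_causal": is_causal, "Is_local": is_local,
            "Is_even_MN": is_even_mn, "Is_even_K": is_even_k}
```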
...@@ -76,19 +79,22 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -76,19 +79,22 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.window_size_left >= 0 || params.window_size_right >= 0, Is_local, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { BOOL_SWITCH(params.num_splits > 1, Split, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>; // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>; // If Is_local, set Is_causal to false
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>; auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal && !Is_local, Is_local, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
if (smem_size >= 48 * 1024) { // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
C10_CUDA_CHECK(cudaFuncSetAttribute( // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); if (smem_size >= 48 * 1024) {
} C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
C10_CUDA_KERNEL_LAUNCH_CHECK(); }
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}); });
}); });
}); });
......
...@@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_ ...@@ -139,10 +139,11 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
} }
} }
template <typename Engine, typename Layout> template <bool HasWSLeft=true, typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_, inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_, const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q, const int warp_row_stride) { const int max_seqlen_q, const int warp_row_stride,
const int window_size_left, const int window_size_right) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor"); static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32; const int lane_id = threadIdx.x % 32;
...@@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i ...@@ -155,14 +156,15 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i
#pragma unroll #pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) { for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8; const int row_idx = row_idx_base + i * 8;
const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
#pragma unroll #pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8; const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll #pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) { for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j; const int col_idx = col_idx_base + j;
if (col_idx >= col_idx_limit) { if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
} }
} }
...@@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i ...@@ -176,6 +178,15 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i
} }
} }
template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q, const int warp_row_stride) {
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
max_seqlen_q, warp_row_stride, -1, 0);
}
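As a reading aid, here is a hedged pure-Python restatement (illustrative, not part of the kernel) of the per-element test that apply_mask_local performs: the window is anchored to the bottom-right of the seqlen_q x seqlen_k score matrix, and causal masking is the special case HasWSLeft=false, window_size_right=0.

```python
# Illustrative reference for the element-wise condition in apply_mask_local.
def is_masked_out(row_idx, col_idx, max_seqlen_q, max_seqlen_k,
                  window_size_left, window_size_right, has_ws_left=True):
    col_idx_limit_left = max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left)
    col_idx_limit_right = min(max_seqlen_k,
                              row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right)
    # Scores at masked positions are set to -INFINITY before the softmax.
    return col_idx >= col_idx_limit_right or (has_ws_left and col_idx < col_idx_limit_left)
```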
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1> template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx( inline __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol, Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
......
...@@ -41,11 +41,21 @@ def _get_block_size(device, head_dim, is_dropout, is_causal): ...@@ -41,11 +41,21 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
return (128, 64) if is_sm80 else (64, 64) return (128, 64) if is_sm80 else (64, 64)
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax): def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)] q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None q,
k,
v,
None,
dropout_p,
softmax_scale,
causal,
window_size[0],
window_size[1],
return_softmax,
None,
) )
return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
...@@ -61,6 +71,7 @@ def _flash_attn_varlen_forward( ...@@ -61,6 +71,7 @@ def _flash_attn_varlen_forward(
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size,
return_softmax, return_softmax,
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
...@@ -78,6 +89,8 @@ def _flash_attn_varlen_forward( ...@@ -78,6 +89,8 @@ def _flash_attn_varlen_forward(
softmax_scale, softmax_scale,
False, False,
causal, causal,
window_size[0],
window_size[1],
return_softmax, return_softmax,
None, None,
) )
...@@ -87,7 +100,20 @@ def _flash_attn_varlen_forward( ...@@ -87,7 +100,20 @@ def _flash_attn_varlen_forward(
def _flash_attn_backward( def _flash_attn_backward(
dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
dropout_p,
softmax_scale,
causal,
window_size,
rng_state=None,
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous # dq, dk, dv are allocated by us so they should already be contiguous
...@@ -105,6 +131,8 @@ def _flash_attn_backward( ...@@ -105,6 +131,8 @@ def _flash_attn_backward(
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size[0],
window_size[1],
None, None,
rng_state, rng_state,
) )
...@@ -128,6 +156,7 @@ def _flash_attn_varlen_backward( ...@@ -128,6 +156,7 @@ def _flash_attn_varlen_backward(
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size,
rng_state=None, rng_state=None,
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
...@@ -151,6 +180,8 @@ def _flash_attn_varlen_backward( ...@@ -151,6 +180,8 @@ def _flash_attn_varlen_backward(
softmax_scale, softmax_scale,
False, False,
causal, causal,
window_size[0],
window_size[1],
None, None,
rng_state, rng_state,
) )
...@@ -161,7 +192,7 @@ def _flash_attn_varlen_backward( ...@@ -161,7 +192,7 @@ def _flash_attn_varlen_backward(
class FlashAttnQKVPackedFunc(torch.autograd.Function): class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax): def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, return_softmax):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5) softmax_scale = qkv.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...@@ -171,12 +202,14 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): ...@@ -171,12 +202,14 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -197,15 +230,26 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): ...@@ -197,15 +230,26 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.dropout_p, ctx.dropout_p,
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size,
rng_state=rng_state, rng_state=rng_state,
) )
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None return dqkv, None, None, None, None, None
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax): def forward(
ctx,
qkv,
cu_seqlens,
max_seqlen,
dropout_p,
softmax_scale,
causal,
window_size,
return_softmax,
):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5) softmax_scale = qkv.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
...@@ -219,6 +263,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): ...@@ -219,6 +263,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
...@@ -226,6 +271,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): ...@@ -226,6 +271,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.max_seqlen = max_seqlen ctx.max_seqlen = max_seqlen
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -250,15 +296,16 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): ...@@ -250,15 +296,16 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.dropout_p, ctx.dropout_p,
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size,
rng_state=rng_state, rng_state=rng_state,
) )
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None return dqkv, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function): class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax): def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, return_softmax):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...@@ -268,12 +315,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function): ...@@ -268,12 +315,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -295,11 +344,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function): ...@@ -295,11 +344,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.dropout_p, ctx.dropout_p,
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]] dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None return dq, dkv, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
...@@ -315,6 +365,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -315,6 +365,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size,
return_softmax, return_softmax,
): ):
if softmax_scale is None: if softmax_scale is None:
...@@ -330,6 +381,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -330,6 +381,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward( ctx.save_for_backward(
...@@ -340,6 +392,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -340,6 +392,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.max_seqlen_k = max_seqlen_k ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -365,16 +418,17 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -365,16 +418,17 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.dropout_p, ctx.dropout_p,
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]] dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None, None return dq, dkv, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function): class FlashAttnFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax): def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...@@ -384,12 +438,14 @@ class FlashAttnFunc(torch.autograd.Function): ...@@ -384,12 +438,14 @@ class FlashAttnFunc(torch.autograd.Function):
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
ctx.dropout_p = dropout_p ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -409,12 +465,13 @@ class FlashAttnFunc(torch.autograd.Function): ...@@ -409,12 +465,13 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.dropout_p, ctx.dropout_p,
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]] dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None return dq, dk, dv, None, None, None, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function): class FlashAttnVarlenFunc(torch.autograd.Function):
...@@ -431,6 +488,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -431,6 +488,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size,
return_softmax, return_softmax,
): ):
if softmax_scale is None: if softmax_scale is None:
...@@ -446,6 +504,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -446,6 +504,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward( ctx.save_for_backward(
...@@ -456,6 +515,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -456,6 +515,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.max_seqlen_k = max_seqlen_k ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -479,16 +539,22 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -479,16 +539,22 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.dropout_p, ctx.dropout_p,
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]] dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None return dq, dk, dv, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func( def flash_attn_qkvpacked_func(
qkv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False qkv,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than If Q, K, V are already stacked into 1 tensor, this function will be faster than
...@@ -497,12 +563,16 @@ def flash_attn_qkvpacked_func( ...@@ -497,12 +563,16 @@ def flash_attn_qkvpacked_func(
For multi-query and grouped-query attention (MQA/GQA), please see For multi-query and grouped-query attention (MQA/GQA), please see
flash_attn_kvpacked_func and flash_attn_func. flash_attn_kvpacked_func and flash_attn_func.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments: Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim) qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability. dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax. softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -515,11 +585,19 @@ def flash_attn_qkvpacked_func( ...@@ -515,11 +585,19 @@ def flash_attn_qkvpacked_func(
The output of softmax (possibly with different scaling). It also encodes the dropout The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept). pattern (negative means that location was dropped, nonnegative means it was kept).
""" """
return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs) return FlashAttnQKVPackedFunc.apply(
qkv, dropout_p, softmax_scale, causal, window_size, return_attn_probs
)
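A small worked example of the window semantics stated in the docstring above (values are illustrative): with equal query/key lengths, query i sees keys in [i - window_size[0], i + window_size[1]], clamped to the valid range.

```python
# Toy check of the [i - left, i + right] window for the packed-QKV case.
seqlen, window_size = 8, (2, 0)
for i in range(seqlen):
    lo = max(0, i - window_size[0])
    hi = min(seqlen - 1, i + window_size[1])
    # e.g. i=5 attends to keys 3, 4, 5; i=0 attends only to key 0.
    assert lo <= hi
```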
def flash_attn_kvpacked_func( def flash_attn_kvpacked_func(
q, kv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False q,
kv,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
If K, V are already stacked into 1 tensor, this function will be faster than If K, V are already stacked into 1 tensor, this function will be faster than
...@@ -542,6 +620,10 @@ def flash_attn_kvpacked_func( ...@@ -542,6 +620,10 @@ def flash_attn_kvpacked_func(
1 1 1 1
If the row of the mask is all zero, the output will be zero. If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments: Arguments:
q: (batch_size, seqlen, nheads, headdim) q: (batch_size, seqlen, nheads, headdim)
kv: (batch_size, seqlen, 2, nheads_k, headdim) kv: (batch_size, seqlen, 2, nheads_k, headdim)
...@@ -549,6 +631,7 @@ def flash_attn_kvpacked_func( ...@@ -549,6 +631,7 @@ def flash_attn_kvpacked_func(
softmax_scale: float. The scaling of QK^T before applying softmax. softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -561,11 +644,20 @@ def flash_attn_kvpacked_func( ...@@ -561,11 +644,20 @@ def flash_attn_kvpacked_func(
The output of softmax (possibly with different scaling). It also encodes the dropout The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept). pattern (negative means that location was dropped, nonnegative means it was kept).
""" """
return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs) return FlashAttnKVPackedFunc.apply(
q, kv, dropout_p, softmax_scale, causal, window_size, return_attn_probs
)
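When seqlen_q != seqlen_k the window is shifted as described in the docstring above; a short numeric sketch (illustrative values only) for single-token decoding against a long key cache:

```python
# With one query against 1024 keys and window_size=(256, 0), the query may only
# attend to the last 257 keys (positions 767..1023 inclusive).
seqlen_q, seqlen_k, window_size = 1, 1024, (256, 0)
i = 0
lo = max(0, i + seqlen_k - seqlen_q - window_size[0])             # 767
hi = min(seqlen_k - 1, i + seqlen_k - seqlen_q + window_size[1])  # 1023
assert (lo, hi) == (767, 1023)
```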
def flash_attn_func( def flash_attn_func(
q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
...@@ -585,6 +677,10 @@ def flash_attn_func( ...@@ -585,6 +677,10 @@ def flash_attn_func(
1 1 1 1
If the row of the mask is all zero, the output will be zero. If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments: Arguments:
q: (batch_size, seqlen, nheads, headdim) q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim) k: (batch_size, seqlen, nheads_k, headdim)
...@@ -593,6 +689,7 @@ def flash_attn_func( ...@@ -593,6 +689,7 @@ def flash_attn_func(
softmax_scale: float. The scaling of QK^T before applying softmax. softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -605,7 +702,9 @@ def flash_attn_func( ...@@ -605,7 +702,9 @@ def flash_attn_func(
The output of softmax (possibly with different scaling). It also encodes the dropout The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept). pattern (negative means that location was dropped, nonnegative means it was kept).
""" """
return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs) return FlashAttnFunc.apply(
q, k, v, dropout_p, softmax_scale, causal, window_size, return_attn_probs
)
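A hedged usage sketch for the new argument (shapes, dtypes, and the import path are assumptions, not taken from this diff): sliding-window attention that behaves like causal attention restricted to the 512 preceding keys.

```python
# Usage sketch only; assumes a CUDA device and the usual `flash_attn` import path.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 4096, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

# causal=True with window_size=(512, 0): each query attends to itself and at most
# the 512 preceding keys.
out = flash_attn_func(q, k, v, causal=True, window_size=(512, 0))
```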
def flash_attn_varlen_qkvpacked_func( def flash_attn_varlen_qkvpacked_func(
...@@ -615,6 +714,7 @@ def flash_attn_varlen_qkvpacked_func( ...@@ -615,6 +714,7 @@ def flash_attn_varlen_qkvpacked_func(
dropout_p=0.0, dropout_p=0.0,
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -624,6 +724,9 @@ def flash_attn_varlen_qkvpacked_func( ...@@ -624,6 +724,9 @@ def flash_attn_varlen_qkvpacked_func(
For multi-query and grouped-query attention (MQA/GQA), please see For multi-query and grouped-query attention (MQA/GQA), please see
flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments: Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
...@@ -633,6 +736,7 @@ def flash_attn_varlen_qkvpacked_func( ...@@ -633,6 +736,7 @@ def flash_attn_varlen_qkvpacked_func(
softmax_scale: float. The scaling of QK^T before applying softmax. softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -646,7 +750,14 @@ def flash_attn_varlen_qkvpacked_func( ...@@ -646,7 +750,14 @@ def flash_attn_varlen_qkvpacked_func(
pattern (negative means that location was dropped, nonnegative means it was kept). pattern (negative means that location was dropped, nonnegative means it was kept).
""" """
return FlashAttnVarlenQKVPackedFunc.apply( return FlashAttnVarlenQKVPackedFunc.apply(
qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs qkv,
cu_seqlens,
max_seqlen,
dropout_p,
softmax_scale,
causal,
window_size,
return_attn_probs,
) )
...@@ -660,6 +771,7 @@ def flash_attn_varlen_kvpacked_func( ...@@ -660,6 +771,7 @@ def flash_attn_varlen_kvpacked_func(
dropout_p=0.0, dropout_p=0.0,
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -683,6 +795,10 @@ def flash_attn_varlen_kvpacked_func( ...@@ -683,6 +795,10 @@ def flash_attn_varlen_kvpacked_func(
1 1 1 1
If the row of the mask is all zero, the output will be zero. If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments: Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...@@ -696,6 +812,7 @@ def flash_attn_varlen_kvpacked_func( ...@@ -696,6 +812,7 @@ def flash_attn_varlen_kvpacked_func(
softmax_scale: float. The scaling of QK^T before applying softmax. softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -718,6 +835,7 @@ def flash_attn_varlen_kvpacked_func( ...@@ -718,6 +835,7 @@ def flash_attn_varlen_kvpacked_func(
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size,
return_attn_probs, return_attn_probs,
) )
...@@ -733,6 +851,7 @@ def flash_attn_varlen_func( ...@@ -733,6 +851,7 @@ def flash_attn_varlen_func(
dropout_p=0.0, dropout_p=0.0,
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -753,6 +872,10 @@ def flash_attn_varlen_func( ...@@ -753,6 +872,10 @@ def flash_attn_varlen_func(
1 1 1 1
If the row of the mask is all zero, the output will be zero. If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments: Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
...@@ -767,6 +890,7 @@ def flash_attn_varlen_func( ...@@ -767,6 +890,7 @@ def flash_attn_varlen_func(
softmax_scale: float. The scaling of QK^T before applying softmax. softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -790,6 +914,7 @@ def flash_attn_varlen_func( ...@@ -790,6 +914,7 @@ def flash_attn_varlen_func(
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size,
return_attn_probs, return_attn_probs,
) )
...@@ -805,6 +930,7 @@ def flash_attn_with_kvcache( ...@@ -805,6 +930,7 @@ def flash_attn_with_kvcache(
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window
rotary_interleaved=True, rotary_interleaved=True,
num_splits=0, num_splits=0,
): ):
...@@ -818,11 +944,12 @@ def flash_attn_with_kvcache( ...@@ -818,11 +944,12 @@ def flash_attn_with_kvcache(
For example, the KV cache could be pre-allocated with the max sequence length, and you can use For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be rotated Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal, the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens, If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
cache_seqlens + 1, etc. If not causal, the query @q will be rotated by rotary_cos and rotary_sin and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
at indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
...@@ -843,6 +970,10 @@ def flash_attn_with_kvcache( ...@@ -843,6 +970,10 @@ def flash_attn_with_kvcache(
1 1 1 1
If the row of the mask is all zero, the output will be zero. If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Note: Does not support backward pass. Note: Does not support backward pass.
Arguments: Arguments:
...@@ -860,6 +991,7 @@ def flash_attn_with_kvcache( ...@@ -860,6 +991,7 @@ def flash_attn_with_kvcache(
softmax_scale: float. The scaling of QK^T before applying softmax. softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
...@@ -894,6 +1026,8 @@ def flash_attn_with_kvcache( ...@@ -894,6 +1026,8 @@ def flash_attn_with_kvcache(
None, None,
softmax_scale, softmax_scale,
causal, causal,
window_size[0],
window_size[1],
rotary_interleaved, rotary_interleaved,
num_splits, num_splits,
) )
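For completeness, a hedged decoding sketch with the KV-cache path (argument names other than causal, window_size, and cache_seqlens are assumptions based on the surrounding interface, not guaranteed by this diff):

```python
# Illustrative only: single-step decoding with a sliding window over the cache.
import torch
from flash_attn import flash_attn_with_kvcache  # assumed import path

batch, nheads, headdim, max_seqlen = 2, 16, 64, 4096
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 1000, dtype=torch.int32, device="cuda")

# Attend only to the last 257 cached keys of each sequence (window_size=(256, 0)).
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True, window_size=(256, 0)
)
```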
......
...@@ -150,8 +150,13 @@ def generate_qkv( ...@@ -150,8 +150,13 @@ def generate_qkv(
) )
def construct_causal_mask( def construct_local_mask(
seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, device=None seqlen_q,
seqlen_k,
window_size=(-1, -1), # -1 means infinite window size
query_padding_mask=None,
key_padding_mask=None,
device=None,
): ):
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
...@@ -165,7 +170,14 @@ def construct_causal_mask( ...@@ -165,7 +170,14 @@ def construct_causal_mask(
if query_padding_mask is None if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
) )
return col_idx > row_idx + sk - sq if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
col_idx < row_idx + sk - sq - window_size[0],
)
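A hedged sanity check (not part of the test suite) that the generalized helper still covers the old behaviour: with window_size=(-1, 0) and no padding masks, construct_local_mask reduces to the usual causal (strictly upper-triangular) mask.

```python
# Illustrative check, assuming it runs in the same module as construct_local_mask.
import torch

seqlen_q = seqlen_k = 5
local = construct_local_mask(seqlen_q, seqlen_k, window_size=(-1, 0))
causal = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool), diagonal=1)
assert torch.equal(local, causal)
```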
def attention_ref( def attention_ref(
...@@ -177,6 +189,7 @@ def attention_ref( ...@@ -177,6 +189,7 @@ def attention_ref(
dropout_p=0.0, dropout_p=0.0,
dropout_mask=None, dropout_mask=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True, upcast=True,
reorder_ops=False, reorder_ops=False,
): ):
...@@ -189,6 +202,8 @@ def attention_ref( ...@@ -189,6 +202,8 @@ def attention_ref(
key_padding_mask: (batch_size, seqlen_k) key_padding_mask: (batch_size, seqlen_k)
dropout_p: float dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16. output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
...@@ -198,6 +213,8 @@ def attention_ref( ...@@ -198,6 +213,8 @@ def attention_ref(
output: (batch_size, seqlen_q, nheads, head_dim) output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
""" """
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype dtype_og = q.dtype
if upcast: if upcast:
q, k, v = q.float(), k.float(), v.float() q, k, v = q.float(), k.float(), v.float()
...@@ -211,17 +228,24 @@ def attention_ref( ...@@ -211,17 +228,24 @@ def attention_ref(
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if key_padding_mask is not None: if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if causal: if window_size[0] >= 0 or window_size[1] >= 0:
# causal_mask = torch.triu( local_mask = construct_local_mask(
# torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 seqlen_q,
# ) seqlen_k,
causal_mask = construct_causal_mask( window_size,
seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device query_padding_mask,
key_padding_mask,
q.device,
) )
scores.masked_fill_(causal_mask, float("-inf")) scores.masked_fill_(local_mask, float("-inf"))
attention = torch.softmax(scores, dim=-1) attention = torch.softmax(scores, dim=-1)
if causal: # Some rows are completely masked out so we fill them with zero instead of NaN # Some rows might be completely masked out so we fill them with zero instead of NaN
attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
dropout_scaling = 1.0 / (1 - dropout_p) dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v) # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
...@@ -232,7 +256,6 @@ def attention_ref( ...@@ -232,7 +256,6 @@ def attention_ref(
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None: if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
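Given the `if causal: window_size = (window_size[0], 0)` shortcut above, causal and local masking share one code path in the reference implementation; a hedged equivalence check (illustrative shapes, assumed to run in the same module):

```python
# Illustrative: attention_ref with causal=True matches window_size=(-1, 0).
import torch

q, k, v = (torch.randn(1, 16, 2, 32) for _ in range(3))
out_causal, _ = attention_ref(q, k, v, None, None, causal=True)
out_local, _ = attention_ref(q, k, v, None, None, window_size=(-1, 0))
assert torch.allclose(out_causal, out_local)
```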
...@@ -244,6 +267,7 @@ def attention_kvpacked_ref( ...@@ -244,6 +267,7 @@ def attention_kvpacked_ref(
dropout_p=0.0, dropout_p=0.0,
dropout_mask=None, dropout_mask=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True, upcast=True,
reorder_ops=False, reorder_ops=False,
): ):
...@@ -257,6 +281,7 @@ def attention_kvpacked_ref( ...@@ -257,6 +281,7 @@ def attention_kvpacked_ref(
dropout_mask, dropout_mask,
upcast=upcast, upcast=upcast,
causal=causal, causal=causal,
window_size=window_size,
reorder_ops=reorder_ops, reorder_ops=reorder_ops,
) )
...@@ -267,6 +292,7 @@ def attention_qkvpacked_ref( ...@@ -267,6 +292,7 @@ def attention_qkvpacked_ref(
dropout_p=0.0, dropout_p=0.0,
dropout_mask=None, dropout_mask=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True, upcast=True,
reorder_ops=False, reorder_ops=False,
): ):
...@@ -280,6 +306,7 @@ def attention_qkvpacked_ref( ...@@ -280,6 +306,7 @@ def attention_qkvpacked_ref(
dropout_mask, dropout_mask,
upcast=upcast, upcast=upcast,
causal=causal, causal=causal,
window_size=window_size,
reorder_ops=reorder_ops, reorder_ops=reorder_ops,
) )
...@@ -327,7 +354,15 @@ def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask ...@@ -327,7 +354,15 @@ def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask
def convert_flash_attn_S_to_softmax( def convert_flash_attn_S_to_softmax(
S, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False S,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
head_dim,
is_dropout,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
): ):
"""FlashAttention stores the S matrix in a different way. """FlashAttention stores the S matrix in a different way.
Arguments: Arguments:
...@@ -335,6 +370,8 @@ def convert_flash_attn_S_to_softmax( ...@@ -335,6 +370,8 @@ def convert_flash_attn_S_to_softmax(
query_padding_mask: (batch_size, seqlen_q_rounded) query_padding_mask: (batch_size, seqlen_q_rounded)
key_padding_mask: (batch_size, seqlen_k_rounded) key_padding_mask: (batch_size, seqlen_k_rounded)
""" """
if causal:
window_size = (window_size[0], 0)
seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:] seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
warps_n = 4 warps_n = 4
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal) blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
...@@ -359,19 +396,21 @@ def convert_flash_attn_S_to_softmax( ...@@ -359,19 +396,21 @@ def convert_flash_attn_S_to_softmax(
four=4, four=4,
) )
if causal: if window_size[0] >= 0 or window_size[1] >= 0:
# causal_mask = torch.triu( local_mask = construct_local_mask(
# torch.ones(seqlen_q_rounded, seqlen_k_rounded, dtype=torch.bool, device=q.device), 1 seqlen_q,
# ) seqlen_k,
causal_mask = construct_causal_mask( window_size,
seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device query_padding_mask,
key_padding_mask,
S.device,
) )
causal_mask = F.pad( local_mask = F.pad(
causal_mask, local_mask,
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
value=True, value=True,
) )
S_converted.masked_fill_(causal_mask, 0.0) S_converted.masked_fill_(local_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values # Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten. # and some of those values aren't overwritten.
...@@ -399,6 +438,7 @@ def normalize_flash_attn_S( ...@@ -399,6 +438,7 @@ def normalize_flash_attn_S(
key_padding_mask=None, key_padding_mask=None,
is_dropout=False, is_dropout=False,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite window size
): ):
""" """
Arguments: Arguments:
...@@ -409,20 +449,24 @@ def normalize_flash_attn_S( ...@@ -409,20 +449,24 @@ def normalize_flash_attn_S(
softmax_lse: (batch_size, nheads, seqlen_q) softmax_lse: (batch_size, nheads, seqlen_q)
softmax_max: (batch_size, nheads, seqlen_q) softmax_max: (batch_size, nheads, seqlen_q)
""" """
if causal:
window_size = (window_size[0], 0)
q, k, v = q.float(), k.float(), v.float() q, k, v = q.float(), k.float(), v.float()
_, seqlen_q, _, head_dim = q.shape _, seqlen_q, _, head_dim = q.shape
seqlen_k = k.shape[1] seqlen_k = k.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k) scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
if key_padding_mask is not None: if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if causal: if window_size[0] >= 0 or window_size[1] >= 0:
# causal_mask = torch.triu( local_mask = construct_local_mask(
# torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 seqlen_q,
# ) seqlen_k,
causal_mask = construct_causal_mask( window_size,
seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device query_padding_mask,
key_padding_mask,
q.device,
) )
scores.masked_fill_(causal_mask, float("-inf")) scores.masked_fill_(local_mask, float("-inf"))
_, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal) _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1) scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
...@@ -446,79 +490,84 @@ def normalize_flash_attn_S( ...@@ -446,79 +490,84 @@ def normalize_flash_attn_S(
def get_dropout_fraction( def get_dropout_fraction(
dropout_mask, query_padding_mask=None, key_padding_mask=None, causal=False dropout_mask,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
): ):
""" """
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop. dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
query_padding_mask: (batch_size, seqlen_q) query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k) key_padding_mask: (batch_size, seqlen_k)
""" """
if causal:
window_size = (window_size[0], 0)
batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
dropped = ~dropout_mask dropped = ~dropout_mask
valid = torch.ones_like(dropout_mask)
if query_padding_mask is not None: if query_padding_mask is not None:
dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False) dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
if key_padding_mask is not None: if key_padding_mask is not None:
dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
if causal: valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
# causal_mask = torch.triu( if window_size[0] >= 0 or window_size[1] >= 0:
# torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1 local_mask = construct_local_mask(
# ) seqlen_q,
causal_mask = construct_causal_mask( seqlen_k,
seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, dropout_mask.device window_size,
query_padding_mask,
key_padding_mask,
dropout_mask.device,
) )
dropped.masked_fill_(causal_mask, False) dropped.masked_fill_(local_mask, False)
valid.masked_fill_(local_mask, False)
dropped_total = dropped.sum() dropped_total = dropped.sum()
query_lengths = ( return dropped.sum() / valid.sum()
query_padding_mask.sum(dim=-1)
if query_padding_mask is not None
else torch.full((batch_size,), seqlen_q, device=dropout_mask.device)
)
key_lengths = (
key_padding_mask.sum(dim=-1)
if key_padding_mask is not None
else torch.full((batch_size,), seqlen_k, device=dropout_mask.device)
)
if not causal:
numel_per_batch = query_lengths * key_lengths
else:
numel_per_batch = torch.where(
key_lengths <= query_lengths,
key_lengths * (key_lengths + 1) / 2,
query_lengths * key_lengths - (query_lengths * (query_lengths - 1) / 2),
)
return dropped_total / (numel_per_batch.sum() * nheads)
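Note: get_dropout_fraction now computes its denominator empirically, masking a tensor of ones with the same padding and local masks as the dropped counts, instead of deriving a closed-form count of valid positions for every window shape. A minimal self-contained illustration of that bookkeeping, with toy shapes and a causal window as the example mask:

import torch

torch.manual_seed(0)
p = 0.17
dropout_mask = torch.rand(2, 3, 128, 128) > p      # True = kept, False = dropped
dropped = ~dropout_mask
valid = torch.ones_like(dropout_mask)
# Example local mask: causal attention, i.e. window_size = (infinite, 0),
# masks the strictly upper-triangular entries.
local_mask = torch.triu(torch.ones(128, 128, dtype=torch.bool), diagonal=1)
dropped = dropped.masked_fill(local_mask, False)
valid = valid.masked_fill(local_mask, False)
print((dropped.sum() / valid.sum()).item())        # close to p in expectation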
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False]) # @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128]) # @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128]) # @pytest.mark.parametrize("seqlen", [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0]) # @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM pytest.skip() # Reference implementation OOM
device = "cuda" device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 16 batch_size = 13
nheads = 9 nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
qkv = torch.randn( qkv = torch.randn(
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
) )
out, lse, S_dmask = flash_attn_qkvpacked_func( out, lse, S_dmask = flash_attn_qkvpacked_func(
qkv, dropout_p, return_attn_probs=True, causal=causal qkv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
) )
if dropout_p > 0.0: if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax( S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, seqlen, seqlen, None, None, d, dropout_p > 0.0, causal=causal S_dmask,
seqlen,
seqlen,
None,
None,
d,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
) )
dropout_mask = S_dmask_converted >= 0 dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs() attn_unnorm = S_dmask_converted.abs()
...@@ -531,15 +580,27 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -531,15 +580,27 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
None, None,
dropout_p > 0.0, dropout_p > 0.0,
causal=causal, causal=causal,
window_size=window_size,
) )
dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item() dropout_fraction = get_dropout_fraction(
dropout_mask, None, None, causal=causal, window_size=window_size
).item()
print(f"Actual dropout fraction: {dropout_fraction}") print(f"Actual dropout fraction: {dropout_fraction}")
else: else:
dropout_mask = None dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref(qkv, None, dropout_p, dropout_mask, causal=causal) out_ref, attn_ref = attention_qkvpacked_ref(
qkv, None, dropout_p, dropout_mask, causal=causal, window_size=window_size
)
out_pt, attn_pt = attention_qkvpacked_ref( out_pt, attn_pt = attention_qkvpacked_ref(
qkv, None, dropout_p, dropout_mask, causal=causal, upcast=False, reorder_ops=True qkv,
None,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
) )
# v = qkv[:, :, 2].float() # v = qkv[:, :, 2].float()
# qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float() # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
...@@ -590,7 +651,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -590,7 +651,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
if dropout_p > 0.0: if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
assert abs(dropout_fraction - dropout_p) <= 0.01 assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
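Note: a hedged usage sketch of the interface exercised by test_flash_attn_qkvpacked above, assuming a CUDA device and the keyword arguments shown in those calls. Local attention is requested solely through the new window_size argument:

import torch
from flash_attn import flash_attn_qkvpacked_func

qkv = torch.randn(2, 1024, 3, 8, 64, device="cuda", dtype=torch.float16)
# Full (non-causal) attention: window_size=(-1, -1) is the default.
out_full = flash_attn_qkvpacked_func(qkv, 0.0, causal=False)
# Sliding window: each query attends to up to 128 keys to its left and 128 to its right.
out_local = flash_attn_qkvpacked_func(qkv, 0.0, causal=False, window_size=(128, 128))
# Causal + local: the reference helpers above treat this as window_size=(128, 0).
out_causal_local = flash_attn_qkvpacked_func(qkv, 0.0, causal=True, window_size=(128, 128))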
...@@ -598,15 +659,18 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -598,15 +659,18 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False]) # @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128]) # @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0]) # @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM pytest.skip() # Reference implementation OOM
device = "cuda" device = "cuda"
...@@ -614,6 +678,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -614,6 +678,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 5 batch_size = 5
nheads = 6 nheads = 6
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
qkv = torch.randn( qkv = torch.randn(
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
) )
...@@ -626,7 +691,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -626,7 +691,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
) )
out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func( out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
qkv_unpad, cu_seqlens, max_seqlen, dropout_p, return_attn_probs=True, causal=causal qkv_unpad,
cu_seqlens,
max_seqlen,
dropout_p,
causal=causal,
window_size=window_size,
return_attn_probs=True,
) )
out = output_pad_fn(out_unpad) out = output_pad_fn(out_unpad)
if dropout_p > 0.0: if dropout_p > 0.0:
...@@ -639,6 +710,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -639,6 +710,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
d, d,
dropout_p > 0.0, dropout_p > 0.0,
causal=causal, causal=causal,
window_size=window_size,
) )
dropout_mask = S_dmask_converted >= 0 dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs() attn_unnorm = S_dmask_converted.abs()
...@@ -651,16 +723,17 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -651,16 +723,17 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
key_padding_mask, key_padding_mask,
dropout_p > 0.0, dropout_p > 0.0,
causal=causal, causal=causal,
window_size=window_size,
) )
dropout_fraction = get_dropout_fraction( dropout_fraction = get_dropout_fraction(
dropout_mask, key_padding_mask, key_padding_mask, causal=causal dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size
).item() ).item()
print(f"Actual dropout fraction: {dropout_fraction}") print(f"Actual dropout fraction: {dropout_fraction}")
else: else:
dropout_mask = None dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref( out_ref, attn_ref = attention_qkvpacked_ref(
qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal, window_size=window_size
) )
out_pt, attn_pt = attention_qkvpacked_ref( out_pt, attn_pt = attention_qkvpacked_ref(
qkv, qkv,
...@@ -668,6 +741,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -668,6 +741,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
window_size=window_size,
upcast=False, upcast=False,
reorder_ops=True, reorder_ops=True,
) )
...@@ -700,7 +774,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -700,7 +774,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
if dropout_p > 0.0: if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
assert abs(dropout_fraction - dropout_p) <= 0.01 assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
...@@ -712,10 +786,12 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -712,10 +786,12 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
# @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"]) # @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [56, 80])
...@@ -738,7 +814,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): ...@@ -738,7 +814,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype):
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17]) # @pytest.mark.parametrize("dropout_p", [0.17])
def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked): def test_flash_attn_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked
):
if ( if (
max(seqlen_q, seqlen_k) >= 2048 max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
...@@ -747,10 +825,11 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ...@@ -747,10 +825,11 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
device = "cuda" device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 16 batch_size = 13
nheads = 9 nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0 assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if kvpacked: if kvpacked:
kv = torch.randn( kv = torch.randn(
...@@ -766,15 +845,23 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ...@@ -766,15 +845,23 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
if kvpacked: if kvpacked:
out, lse, S_dmask = flash_attn_kvpacked_func( out, lse, S_dmask = flash_attn_kvpacked_func(
q, kv, dropout_p, return_attn_probs=True, causal=causal q, kv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
) )
else: else:
out, lse, S_dmask = flash_attn_func( out, lse, S_dmask = flash_attn_func(
q, k, v, dropout_p, return_attn_probs=True, causal=causal q, k, v, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True
) )
if dropout_p > 0.0: if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax( S_dmask_converted = convert_flash_attn_S_to_softmax(
S_dmask, seqlen_q, seqlen_k, None, None, d, dropout_p > 0.0, causal=causal S_dmask,
seqlen_q,
seqlen_k,
None,
None,
d,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
) )
dropout_mask = S_dmask_converted >= 0 dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs() attn_unnorm = S_dmask_converted.abs()
...@@ -785,16 +872,33 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ...@@ -785,16 +872,33 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
attn = normalize_flash_attn_S( attn = normalize_flash_attn_S(
attn_unnorm, q, k_rep, v_rep, None, None, dropout_p > 0.0, causal=causal attn_unnorm,
q,
k_rep,
v_rep,
None,
None,
dropout_p > 0.0,
causal=causal,
window_size=window_size,
) )
dropout_fraction = get_dropout_fraction(dropout_mask, None, None, causal=causal).item() dropout_fraction = get_dropout_fraction(
dropout_mask, None, None, causal=causal, window_size=window_size
).item()
print(f"Actual dropout fraction: {dropout_fraction}") print(f"Actual dropout fraction: {dropout_fraction}")
else: else:
dropout_mask = None dropout_mask = None
if kvpacked: if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref( out_ref, attn_ref = attention_kvpacked_ref(
q, kv, None, None, dropout_p, dropout_mask, causal=causal q,
kv,
None,
None,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
) )
out_pt, attn_pt = attention_kvpacked_ref( out_pt, attn_pt = attention_kvpacked_ref(
q, q,
...@@ -804,12 +908,21 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ...@@ -804,12 +908,21 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
window_size=window_size,
upcast=False, upcast=False,
reorder_ops=True, reorder_ops=True,
) )
else: else:
out_ref, attn_ref = attention_ref( out_ref, attn_ref = attention_ref(
q, k, v, None, None, dropout_p, dropout_mask, causal=causal q,
k,
v,
None,
None,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
) )
out_pt, attn_pt = attention_ref( out_pt, attn_pt = attention_ref(
q, q,
...@@ -820,6 +933,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ...@@ -820,6 +933,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
window_size=window_size,
upcast=False, upcast=False,
reorder_ops=True, reorder_ops=True,
) )
...@@ -886,7 +1000,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ...@@ -886,7 +1000,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
if dropout_p > 0.0: if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
assert abs(dropout_fraction - dropout_p) <= 0.01 assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
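Note: test_flash_attn_output also covers seqlen_q != seqlen_k, where the window (like the causal mask in this codebase) is aligned to the bottom-right corner, so query i is centred on key i + seqlen_k - seqlen_q rather than on key i. The helper below is hypothetical, purely to spell out the visible key range under that convention, assuming no padding:

def visible_key_range(i, seqlen_q, seqlen_k, window_size):
    # Inclusive [lo, hi] range of key indices query i may attend to; -1 means unbounded.
    left, right = window_size
    diag = i + seqlen_k - seqlen_q                       # bottom-right-aligned diagonal
    lo = 0 if left < 0 else max(0, diag - left)
    hi = seqlen_k - 1 if right < 0 else min(seqlen_k - 1, diag + right)
    return lo, hi

print(visible_key_range(0, seqlen_q=3, seqlen_k=8, window_size=(2, 0)))   # (3, 5)
print(visible_key_range(2, seqlen_q=3, seqlen_k=8, window_size=(-1, 1)))  # (0, 7)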
...@@ -900,10 +1014,12 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ...@@ -900,10 +1014,12 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
# @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"]) # @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True]) # @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"seqlen_q,seqlen_k", "seqlen_q,seqlen_k",
...@@ -925,7 +1041,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ...@@ -925,7 +1041,7 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0]) # @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output( def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked
): ):
if ( if (
max(seqlen_q, seqlen_k) >= 2048 max(seqlen_q, seqlen_k) >= 2048
...@@ -935,10 +1051,11 @@ def test_flash_attn_varlen_output( ...@@ -935,10 +1051,11 @@ def test_flash_attn_varlen_output(
device = "cuda" device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 16 batch_size = 13
nheads = 9 nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0 assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if kvpacked: if kvpacked:
kv = torch.randn( kv = torch.randn(
...@@ -980,6 +1097,7 @@ def test_flash_attn_varlen_output( ...@@ -980,6 +1097,7 @@ def test_flash_attn_varlen_output(
dropout_p, dropout_p,
return_attn_probs=True, return_attn_probs=True,
causal=causal, causal=causal,
window_size=window_size,
) )
else: else:
( (
...@@ -1008,6 +1126,7 @@ def test_flash_attn_varlen_output( ...@@ -1008,6 +1126,7 @@ def test_flash_attn_varlen_output(
dropout_p, dropout_p,
return_attn_probs=True, return_attn_probs=True,
causal=causal, causal=causal,
window_size=window_size,
) )
out = output_pad_fn(out_unpad) out = output_pad_fn(out_unpad)
if dropout_p > 0.0: if dropout_p > 0.0:
...@@ -1020,6 +1139,7 @@ def test_flash_attn_varlen_output( ...@@ -1020,6 +1139,7 @@ def test_flash_attn_varlen_output(
d, d,
dropout_p > 0.0, dropout_p > 0.0,
causal=causal, causal=causal,
window_size=window_size,
) )
dropout_mask = S_dmask_converted >= 0 dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs() attn_unnorm = S_dmask_converted.abs()
...@@ -1038,9 +1158,14 @@ def test_flash_attn_varlen_output( ...@@ -1038,9 +1158,14 @@ def test_flash_attn_varlen_output(
key_padding_mask, key_padding_mask,
dropout_p > 0.0, dropout_p > 0.0,
causal=causal, causal=causal,
window_size=window_size,
) )
dropout_fraction = get_dropout_fraction( dropout_fraction = get_dropout_fraction(
dropout_mask, query_padding_mask, key_padding_mask, causal=causal dropout_mask,
query_padding_mask,
key_padding_mask,
causal=causal,
window_size=window_size,
).item() ).item()
print(f"Actual dropout fraction: {dropout_fraction}") print(f"Actual dropout fraction: {dropout_fraction}")
else: else:
...@@ -1048,7 +1173,14 @@ def test_flash_attn_varlen_output( ...@@ -1048,7 +1173,14 @@ def test_flash_attn_varlen_output(
if kvpacked: if kvpacked:
out_ref, attn_ref = attention_kvpacked_ref( out_ref, attn_ref = attention_kvpacked_ref(
q, kv, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal q,
kv,
query_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
) )
out_pt, attn_pt = attention_kvpacked_ref( out_pt, attn_pt = attention_kvpacked_ref(
q, q,
...@@ -1058,12 +1190,21 @@ def test_flash_attn_varlen_output( ...@@ -1058,12 +1190,21 @@ def test_flash_attn_varlen_output(
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
window_size=window_size,
upcast=False, upcast=False,
reorder_ops=True, reorder_ops=True,
) )
else: else:
out_ref, attn_ref = attention_ref( out_ref, attn_ref = attention_ref(
q, k, v, query_padding_mask, key_padding_mask, dropout_p, dropout_mask, causal=causal q,
k,
v,
query_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
) )
out_pt, attn_pt = attention_ref( out_pt, attn_pt = attention_ref(
q, q,
...@@ -1074,6 +1215,7 @@ def test_flash_attn_varlen_output( ...@@ -1074,6 +1215,7 @@ def test_flash_attn_varlen_output(
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
window_size=window_size,
upcast=False, upcast=False,
reorder_ops=True, reorder_ops=True,
) )
...@@ -1142,7 +1284,7 @@ def test_flash_attn_varlen_output( ...@@ -1142,7 +1284,7 @@ def test_flash_attn_varlen_output(
if dropout_p > 0.0: if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
assert abs(dropout_fraction - dropout_p) <= 0.01 assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
...@@ -1152,8 +1294,10 @@ def test_flash_attn_varlen_output( ...@@ -1152,8 +1294,10 @@ def test_flash_attn_varlen_output(
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [56, 80])
...@@ -1176,7 +1320,7 @@ def test_flash_attn_varlen_output( ...@@ -1176,7 +1320,7 @@ def test_flash_attn_varlen_output(
], ],
) )
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
if ( if (
max(seqlen_q, seqlen_k) >= 2048 max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
...@@ -1188,13 +1332,16 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): ...@@ -1188,13 +1332,16 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
causal = True causal = True
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 16 batch_size = 13
nheads = 9 nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
out = flash_attn_func(q, k, v, 0.0, causal=causal) out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal) out_ref, attn_ref = attention_ref(
q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
)
out_pt, attn_pt = attention_ref( out_pt, attn_pt = attention_ref(
q, q,
k, k,
...@@ -1204,6 +1351,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): ...@@ -1204,6 +1351,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
0.0, 0.0,
None, None,
causal=causal, causal=causal,
window_size=window_size,
upcast=False, upcast=False,
reorder_ops=True, reorder_ops=True,
) )
...@@ -1256,12 +1404,14 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): ...@@ -1256,12 +1404,14 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128]) # @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True]) @pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True]) # @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -1280,7 +1430,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): ...@@ -1280,7 +1430,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
], ],
) )
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
if ( if (
max(seqlen_q, seqlen_k) >= 2048 max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
...@@ -1292,8 +1442,9 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): ...@@ -1292,8 +1442,9 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
causal = True causal = True
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 16 batch_size = 13
nheads = 9 nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
...@@ -1324,10 +1475,19 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): ...@@ -1324,10 +1475,19 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
max_seqlen_k, max_seqlen_k,
0.0, 0.0,
causal=causal, causal=causal,
window_size=window_size,
) )
out = output_pad_fn(out_unpad) out = output_pad_fn(out_unpad)
out_ref, attn_ref = attention_ref( out_ref, attn_ref = attention_ref(
q, k, v, query_padding_mask, key_padding_mask, 0.0, None, causal=causal q,
k,
v,
query_padding_mask,
key_padding_mask,
0.0,
None,
causal=causal,
window_size=window_size,
) )
out_pt, attn_pt = attention_ref( out_pt, attn_pt = attention_ref(
q, q,
...@@ -1338,6 +1498,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): ...@@ -1338,6 +1498,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
0.0, 0.0,
None, None,
causal=causal, causal=causal,
window_size=window_size,
upcast=False, upcast=False,
reorder_ops=True, reorder_ops=True,
) )
...@@ -1393,6 +1554,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): ...@@ -1393,6 +1554,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16]) # @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
...@@ -1418,7 +1581,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): ...@@ -1418,7 +1581,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
], ],
) )
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
if swap_sq_sk: if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda" device = "cuda"
...@@ -1426,11 +1589,16 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): ...@@ -1426,11 +1589,16 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 1 batch_size = 1
nheads = 12 nheads = 12
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
out, lse, _ = flash_attn_func(q, k, v, 0.0, causal=causal, return_attn_probs=True) out, lse, _ = flash_attn_func(
out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal) q, k, v, 0.0, causal=causal, window_size=window_size, return_attn_probs=True
)
out_ref, attn_ref = attention_ref(
q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size
)
out_pt, attn_pt = attention_ref( out_pt, attn_pt = attention_ref(
q, q,
k, k,
...@@ -1440,6 +1608,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): ...@@ -1440,6 +1608,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
0.0, 0.0,
None, None,
causal=causal, causal=causal,
window_size=window_size,
upcast=False, upcast=False,
reorder_ops=True, reorder_ops=True,
) )
...@@ -1498,6 +1667,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): ...@@ -1498,6 +1667,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
# @pytest.mark.parametrize("mha_type", ["mha"]) # @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True]) @pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [True]) # @pytest.mark.parametrize("new_kv", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
...@@ -1506,7 +1677,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype): ...@@ -1506,7 +1677,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
# @pytest.mark.parametrize("rotary_interleaved", [False]) # @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [56, 80])
...@@ -1536,6 +1707,7 @@ def test_flash_attn_kvcache( ...@@ -1536,6 +1707,7 @@ def test_flash_attn_kvcache(
rotary_interleaved, rotary_interleaved,
seqlen_new_eq_seqlen_q, seqlen_new_eq_seqlen_q,
causal, causal,
local,
new_kv, new_kv,
mha_type, mha_type,
num_splits, num_splits,
...@@ -1554,6 +1726,7 @@ def test_flash_attn_kvcache( ...@@ -1554,6 +1726,7 @@ def test_flash_attn_kvcache(
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0 assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
if new_kv: if new_kv:
...@@ -1566,7 +1739,7 @@ def test_flash_attn_kvcache( ...@@ -1566,7 +1739,7 @@ def test_flash_attn_kvcache(
cache_seqlens = torch.randint( cache_seqlens = torch.randint(
0, 0,
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(seqlen_k - (seqlen_q if causal and rotary_dim > 1 else seqlen_new) + 1) (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
if new_kv if new_kv
else (seqlen_k + 1), else (seqlen_k + 1),
(batch_size,), (batch_size,),
...@@ -1578,7 +1751,7 @@ def test_flash_attn_kvcache( ...@@ -1578,7 +1751,7 @@ def test_flash_attn_kvcache(
angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype) cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype) sin = torch.sin(angle).to(dtype=dtype)
if causal: if causal or local:
q_ro = apply_rotary_emb( q_ro = apply_rotary_emb(
q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
) )
...@@ -1624,11 +1797,14 @@ def test_flash_attn_kvcache( ...@@ -1624,11 +1797,14 @@ def test_flash_attn_kvcache(
sin, sin,
cache_seqlens, cache_seqlens,
causal=causal, causal=causal,
window_size=window_size,
rotary_interleaved=rotary_interleaved, rotary_interleaved=rotary_interleaved,
num_splits=num_splits, num_splits=num_splits,
) )
# out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal) # out = flash_attn_with_kvcache(
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal) # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
# )
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
# m = qk.amax(-1, keepdim=True) # m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d)) # s_tmp = torch.exp((qk - m) / math.sqrt(d))
...@@ -1637,7 +1813,15 @@ def test_flash_attn_kvcache( ...@@ -1637,7 +1813,15 @@ def test_flash_attn_kvcache(
# probs = torch.softmax(qk, dim=-1) # probs = torch.softmax(qk, dim=-1)
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0) key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
out_ref, _ = attention_ref( out_ref, _ = attention_ref(
q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal q_ro,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
0.0,
None,
causal=causal,
window_size=window_size,
) )
out_pt, _ = attention_ref( out_pt, _ = attention_ref(
q_ro, q_ro,
...@@ -1648,6 +1832,7 @@ def test_flash_attn_kvcache( ...@@ -1648,6 +1832,7 @@ def test_flash_attn_kvcache(
0.0, 0.0,
None, None,
causal=causal, causal=causal,
window_size=window_size,
upcast=False, upcast=False,
reorder_ops=True, reorder_ops=True,
) )
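Note: the same window_size keyword flows through the kv-cache decoding path, as the commented-out flash_attn_with_kvcache calls above show. A hedged usage sketch, with illustrative shapes and an assumed 256-token look-back:

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, d, cache_len = 2, 8, 64, 4096
q = torch.randn(batch, 1, nheads, d, device="cuda", dtype=torch.float16)      # one decode step
k_cache = torch.randn(batch, cache_len, nheads, d, device="cuda", dtype=torch.float16)
v_cache = torch.randn(batch, cache_len, nheads, d, device="cuda", dtype=torch.float16)
cache_seqlens = torch.full((batch,), 1000, dtype=torch.int32, device="cuda")  # tokens already cached
# Sliding-window decoding: each new token attends only to the most recent ~256 cached tokens.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True, window_size=(256, 0)
)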