"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "51a8d23392fa0406976a0a782035ccf70922934b"
Commit 73265458 authored by Tri Dao's avatar Tri Dao
Browse files

Implement deterministic backward (thanks to Meituan)

parent 2c7d7b73
@@ -83,7 +83,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
-                          window_size=(-1, -1), alibi_slopes=None):
+                          window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -99,6 +99,8 @@ Arguments:
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
        the attention score of query i and key j.
+    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+        which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
@@ -106,7 +108,7 @@ Return:
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
-                window_size=(-1, -1), alibi_slopes=None):
+                window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
@@ -128,6 +130,8 @@ Arguments:
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
        is added to the attention score of query i and key j.
+    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+        which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
@@ -269,10 +273,12 @@ Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
-### 2.4: ALiBi (attention with linear bias)
+### 2.4: ALiBi (attention with linear bias), deterministic backward pass.
Implement ALiBi (Press et el., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
+Implement deterministic backward pass. Thanks to engineers from [Meituan](www.meituan.com) for this contribution.
## Performance
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
......
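For context, the README hunk above documents the new `deterministic` flag on `flash_attn_qkvpacked_func` and `flash_attn_func`. A minimal, hedged usage sketch follows; the shapes, dtype, and the `causal=True` choice are illustrative only, and it assumes a CUDA GPU with a flash-attn build that includes this commit:

```python
import torch
from flash_attn import flash_attn_func

# Illustrative shapes: (batch, seqlen, nheads, headdim), fp16 on GPU (assumed setup).
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# deterministic=True selects the new backward pass that always accumulates dQ in
# the same order, trading a little speed and extra workspace memory for run-to-run
# reproducible gradients. The forward pass is deterministic either way.
out = flash_attn_func(q, k, v, causal=True, deterministic=True)
out.sum().backward()
```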
@@ -150,7 +150,8 @@ void set_params_dgrad(Flash_bwd_params &params,
        float p_dropout,
        float softmax_scale,
        int window_size_left,
-       int window_size_right) {
+       int window_size_right,
+       bool deterministic) {
    set_params_fprop(params,
        b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -192,6 +193,8 @@ void set_params_dgrad(Flash_bwd_params &params,
    // Softmax sum
    params.dsoftmax_sum = dsoftmax_sum_d;
+   params.deterministic = deterministic;
}
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
@@ -618,8 +621,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
        params.alibi_slopes_ptr = nullptr;
    }
-   auto stream = at::cuda::getCurrentCUDAStream().stream();
-   run_mha_fwd(params, stream);
+   if (max_seqlen_k > 0) {
+       auto stream = at::cuda::getCurrentCUDAStream().stream();
+       run_mha_fwd(params, stream);
+   } else {
+       // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+       out.zero_();
+       softmax_lse.fill_(std::numeric_limits<float>::infinity());
+   }
    at::Tensor out_padded = out;
    if (head_size_og % 8 != 0) {
@@ -668,6 +677,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
        const bool is_causal,
        const int window_size_left,
        int window_size_right,
+       const bool deterministic,
        c10::optional<at::Generator> gen_,
        c10::optional<at::Tensor> &rng_state) {
@@ -783,7 +793,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
    at::Tensor dq_accum;
    at::Tensor dk_accum, dv_accum;
    if (loop) {
-       dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+       if (!deterministic) {
+           dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+       } else {
+           const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
+           dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+       }
        // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
        // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    }
@@ -819,7 +834,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
        p_dropout,
        softmax_scale,
        window_size_left,
-       window_size_right);
+       window_size_right,
+       deterministic);
+   params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
    auto launch = &run_mha_bwd;
    // launch(params, stream, /*configure=*/true);
@@ -857,8 +874,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
        launch(params, stream, /*configure=*/false);
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
-       dk.zero_();
-       dv.zero_();
+       dk_expanded.zero_();
+       dv_expanded.zero_();
        softmax_d.zero_();
    }
@@ -897,6 +914,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
        const bool is_causal,
        const int window_size_left,
        int window_size_right,
+       const bool deterministic,
        c10::optional<at::Generator> gen_,
        c10::optional<at::Tensor> &rng_state) {
@@ -1025,7 +1043,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
        // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
        // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
        // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
-       dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+       if (!deterministic) {
+           dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+       } else {
+           const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
+           dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
+       }
    }
    at::Tensor dk_expanded, dv_expanded;
@@ -1064,7 +1087,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
        p_dropout,
        softmax_scale,
        window_size_left,
-       window_size_right);
+       window_size_right,
+       deterministic);
+   params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
    auto launch = &run_mha_bwd;
    // launch(params, stream, /*configure=*/true);
@@ -1098,7 +1123,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
        params.alibi_slopes_ptr = nullptr;
    }
-   launch(params, stream, /*configure=*/false);
+   if (max_seqlen_q > 0) {
+       launch(params, stream, /*configure=*/false);
+   } else {
+       // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
+       dk_expanded.zero_();
+       dv_expanded.zero_();
+       softmax_d.zero_();
+   }
    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
......
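The key allocation change in the hunks above sizes the deterministic dQ workspace by `nsplits = ceil(SM count / (batch_size * num_heads))`, so each of the (at most) `nsplits` thread blocks that share a (batch, head) slice accumulates into its own fp32 buffer instead of racing on one. A hedged sketch of that arithmetic; the SM count and problem sizes below are made-up examples, not values from the commit:

```python
# Sketch of the dq_accum sizing used when deterministic=True (non-varlen path).
# sm_count corresponds to dprops->multiProcessorCount, queried from the GPU at runtime.
def dq_accum_shape(sm_count, batch_size, num_heads, seqlen_q_rounded, head_size_rounded):
    # Ceil-division: how many backward thread blocks can share one (batch, head) slice.
    nsplits = (sm_count + batch_size * num_heads - 1) // (batch_size * num_heads)
    # One fp32 accumulation buffer per split, versus a single buffer in the
    # non-deterministic path; this is where the extra memory goes.
    return (nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded)

print(dq_accum_shape(sm_count=108, batch_size=4, num_heads=9,
                     seqlen_q_rounded=1024, head_size_rounded=64))
# -> (3, 4, 1024, 9, 64): 3x the dQ accumulation memory of the non-deterministic path.
```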
@@ -172,6 +172,9 @@ struct Flash_bwd_params : public Flash_fwd_params {
    // The pointer to the softmax d sum.
    void *__restrict__ dsoftmax_sum;
+   bool deterministic;
+   index_t dq_accum_split_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
......
@@ -230,7 +230,7 @@ inline __device__ void clear_dKVaccum(const Params &params) {
// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<typename Kernel_traits, typename Params>
-inline __device__ void convert_dQ(const Params &params) {
+inline __device__ void convert_dQ(const Params &params, const int nsplits) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;
@@ -285,11 +285,15 @@ inline __device__ void convert_dQ(const Params &params) {
    CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
    Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
-   cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
-   #pragma unroll
-   for (int i = 0; i < size(acc_dq); ++i) {
-       acc_dq(i) = tdQrdQaccum(i) * params.scale_softmax_rp_dropout;
+   clear(acc_dq);
+   for (int s = 0; s < nsplits; ++s) {
+       cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
+       #pragma unroll
+       for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
+       tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
    }
+   #pragma unroll
+   for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
    // Convert acc_dq from fp32 to fp16
    Tensor rdQ = flash::convert_type<Element>(acc_dq);
    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
@@ -466,7 +470,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
        + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
-       + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
+       + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
+       // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
+       + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
    const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
        + (m_block_max - 1) * kBlockM;
    const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
@@ -715,7 +721,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
        tdKsQt.data() = tdKsQt.data() + size(sQ);
    }
-   if (!Is_first && !Seq_parallel) { __syncthreads(); }
+   if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); }
    if (Kernel_traits::Is_V_in_regs) {
        // Clear the smem tiles to account for predicated off loads
@@ -1604,13 +1610,15 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
-   const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
-   compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+   // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
+   for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
+       compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+   }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......
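The `convert_dQ` change above now reads the per-split partial dQ buffers one by one in a fixed order, adds them into the register accumulator, and only then applies the softmax/dropout rescaling; together with each thread block writing to its own split (the `blockIdx.x * dq_accum_split_stride` offset), this pins down the summation order. A rough numerical analogue in plain PyTorch, not the CUTLASS/cute kernel code (buffer layout simplified to one tile):

```python
import torch

def convert_dq_reference(dq_accum_splits, scale):
    """dq_accum_splits: (nsplits, seqlen_q, nheads, headdim) fp32 partial dQ sums,
    one slice per split buffer. Summing over the split dimension in a fixed loop
    order and scaling afterwards mirrors the `for (int s = 0; s < nsplits; ++s)`
    loop in convert_dQ, so the result does not depend on which thread blocks
    finished their atomicAdds first."""
    acc = torch.zeros_like(dq_accum_splits[0])
    for s in range(dq_accum_splits.shape[0]):  # fixed iteration order over splits
        acc += dq_accum_splits[s]
    return (acc * scale).to(torch.float16)

# Illustrative call with random partial sums; scale stands in for scale_softmax_rp_dropout.
parts = torch.randn(3, 128, 9, 64)
dq = convert_dq_reference(parts, scale=0.125)
```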
@@ -35,8 +35,8 @@ __global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params pa
}
template<typename Kernel_traits>
-__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) {
-   flash::convert_dQ<Kernel_traits>(params);
+__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params, const int nsplits) {
+   flash::convert_dQ<Kernel_traits>(params, nsplits);
}
template<typename Kernel_traits>
@@ -49,9 +49,18 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
-   dim3 grid_n(num_n_block, params.b, params.h);
+   int gridDimx = num_n_block;
+   if (params.deterministic) {
+       auto dprops = at::cuda::getCurrentDeviceProperties();
+       gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
+   }
+   dim3 grid_n(gridDimx, params.b, params.h);
-   flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
+   if (!params.deterministic) {
+       flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
+   } else {
+       flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
+   }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
@@ -69,6 +78,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
    // If Is_local, set Is_causal to false
    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+   // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
    if (smem_size_dq_dk_dv >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -86,7 +96,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
-   kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
+   kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
......
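In the deterministic path the launch above caps the x-dimension of the backward grid at `ceil(SM count / (b * h))` and the kernel compensates with the grid-stride loop over `n_block` shown in the previous file, so each x-block processes a fixed, reproducible set of key/value column blocks. A hedged sketch of that block-to-work mapping in plain Python (sizes are illustrative, `kBlockN=128` is just an example tile size):

```python
def n_blocks_for(block_x, grid_dim_x, seqlen_k, kBlockN=128):
    """Column blocks handled by one x-block in the deterministic backward:
    a grid-stride loop with fixed stride, so the assignment (and hence each
    block's accumulation order into its own dq_accum split) is reproducible."""
    num_n_block = (seqlen_k + kBlockN - 1) // kBlockN
    return list(range(block_x, num_n_block, grid_dim_x))

# e.g. gridDimx = 3 (see the nsplits sketch earlier), seqlen_k = 1024 -> 8 column blocks
for bx in range(3):
    print(bx, n_blocks_for(bx, grid_dim_x=3, seqlen_k=1024))
# 0 [0, 3, 6]
# 1 [1, 4, 7]
# 2 [2, 5]
```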
@@ -52,6 +52,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
    // If Is_local, set Is_causal to false
    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+   // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
    // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
    // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
    if (smem_size >= 48 * 1024) {
......
@@ -122,6 +122,7 @@ def _flash_attn_backward(
    causal,
    window_size,
    alibi_slopes,
+   deterministic,
    rng_state=None,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
@@ -143,6 +144,7 @@ def _flash_attn_backward(
        causal,
        window_size[0],
        window_size[1],
+       deterministic,
        None,
        rng_state,
    )
@@ -168,6 +170,7 @@ def _flash_attn_varlen_backward(
    causal,
    window_size,
    alibi_slopes,
+   deterministic,
    rng_state=None,
):
    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
@@ -194,6 +197,7 @@ def _flash_attn_varlen_backward(
        causal,
        window_size[0],
        window_size[1],
+       deterministic,
        None,
        rng_state,
    )
@@ -205,7 +209,15 @@ def _flash_attn_varlen_backward(
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
-       ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+       ctx,
+       qkv,
+       dropout_p,
+       softmax_scale,
+       causal,
+       window_size,
+       alibi_slopes,
+       deterministic,
+       return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -226,6 +238,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
+       ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)
    @staticmethod
@@ -248,10 +261,11 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
+           ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-       return dqkv, None, None, None, None, None, None
+       return dqkv, None, None, None, None, None, None, None
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@@ -266,6 +280,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
        causal,
        window_size,
        alibi_slopes,
+       deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
@@ -292,6 +307,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
+       ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)
    @staticmethod
@@ -318,16 +334,26 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
+           ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
-       return dqkv, None, None, None, None, None, None, None, None
+       return dqkv, None, None, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
-       ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+       ctx,
+       q,
+       kv,
+       dropout_p,
+       softmax_scale,
+       causal,
+       window_size,
+       alibi_slopes,
+       deterministic,
+       return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
@@ -348,6 +374,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
+       ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)
    @staticmethod
@@ -371,11 +398,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
+           ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
-       return dq, dkv, None, None, None, None, None, None
+       return dq, dkv, None, None, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
@@ -393,6 +421,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
        causal,
        window_size,
        alibi_slopes,
+       deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
@@ -422,6 +451,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
+       ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)
    @staticmethod
@@ -449,17 +479,28 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
+           ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
-       return dq, dkv, None, None, None, None, None, None, None, None, None, None
+       return dq, dkv, None, None, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(
-       ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+       ctx,
+       q,
+       k,
+       v,
+       dropout_p,
+       softmax_scale,
+       causal,
+       window_size,
+       alibi_slopes,
+       deterministic,
+       return_softmax,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
@@ -480,6 +521,7 @@ class FlashAttnFunc(torch.autograd.Function):
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
+       ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)
    @staticmethod
@@ -501,12 +543,13 @@ class FlashAttnFunc(torch.autograd.Function):
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
+           ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
-       return dq, dk, dv, None, None, None, None, None, None
+       return dq, dk, dv, None, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -525,6 +568,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        causal,
        window_size,
        alibi_slopes,
+       deterministic,
        return_softmax,
    ):
        if softmax_scale is None:
@@ -554,6 +598,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.alibi_slopes = alibi_slopes
+       ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)
    @staticmethod
@@ -579,12 +624,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
            ctx.causal,
            ctx.window_size,
            ctx.alibi_slopes,
+           ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
-       return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
+       return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func(
@@ -594,6 +640,7 @@ def flash_attn_qkvpacked_func(
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
+   deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
@@ -615,6 +662,8 @@ def flash_attn_qkvpacked_func(
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
            the attention score of query i and key j.
+       deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+           which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -628,7 +677,14 @@ def flash_attn_qkvpacked_func(
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(
-       qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
+       qkv,
+       dropout_p,
+       softmax_scale,
+       causal,
+       window_size,
+       alibi_slopes,
+       deterministic,
+       return_attn_probs,
    )
@@ -640,6 +696,7 @@ def flash_attn_kvpacked_func(
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
+   deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
@@ -678,6 +735,8 @@ def flash_attn_kvpacked_func(
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
+       deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+           which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -691,7 +750,15 @@ def flash_attn_kvpacked_func(
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnKVPackedFunc.apply(
-       q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
+       q,
+       kv,
+       dropout_p,
+       softmax_scale,
+       causal,
+       window_size,
+       alibi_slopes,
+       deterministic,
+       return_attn_probs,
    )
@@ -704,6 +771,7 @@ def flash_attn_func(
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
+   deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
@@ -740,6 +808,8 @@ def flash_attn_func(
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
+       deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+           which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -753,7 +823,16 @@ def flash_attn_func(
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnFunc.apply(
-       q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
+       q,
+       k,
+       v,
+       dropout_p,
+       softmax_scale,
+       causal,
+       window_size,
+       alibi_slopes,
+       deterministic,
+       return_attn_probs,
    )
@@ -766,6 +845,7 @@ def flash_attn_varlen_qkvpacked_func(
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
+   deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
@@ -790,6 +870,8 @@ def flash_attn_varlen_qkvpacked_func(
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
            is added to the attention score of query i and key j.
+       deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+           which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -811,6 +893,7 @@ def flash_attn_varlen_qkvpacked_func(
        causal,
        window_size,
        alibi_slopes,
+       deterministic,
        return_attn_probs,
    )
@@ -827,6 +910,7 @@ def flash_attn_varlen_kvpacked_func(
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
+   deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
@@ -871,6 +955,8 @@ def flash_attn_varlen_kvpacked_func(
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
+       deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+           which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -895,6 +981,7 @@ def flash_attn_varlen_kvpacked_func(
        causal,
        window_size,
        alibi_slopes,
+       deterministic,
        return_attn_probs,
    )
@@ -912,6 +999,7 @@ def flash_attn_varlen_func(
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    alibi_slopes=None,
+   deterministic=False,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
@@ -954,6 +1042,8 @@ def flash_attn_varlen_func(
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
+       deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+           which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -979,6 +1069,7 @@ def flash_attn_varlen_func(
        causal,
        window_size,
        alibi_slopes,
+       deterministic,
        return_attn_probs,
    )
......
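One mechanical consequence visible in every autograd Function above: `forward` gains one argument (`deterministic`), so `backward` must return one more `None`, because torch.autograd expects exactly one gradient (or `None`) per `forward` input. A toy, self-contained illustration of that rule (plain PyTorch, not FlashAttention code):

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, deterministic):  # three inputs after ctx
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward input: a gradient for x, None for the two
        # non-tensor flags. Adding a `deterministic` argument to forward is what
        # forces the extra trailing None seen throughout the diff above.
        return grad_out * ctx.scale, None, None

x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0, True).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```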
@@ -566,10 +566,12 @@ def get_dropout_fraction(
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
-# @pytest.mark.parametrize("alibi", [True])
+# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
-# @pytest.mark.parametrize("local", [True])
+# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@@ -578,16 +580,16 @@ def get_dropout_fraction(
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
-# @pytest.mark.parametrize("seqlen", [97])
+# @pytest.mark.parametrize("seqlen", [512])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
-def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
+def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
-   batch_size = 8
+   batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
@@ -604,6 +606,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
+       deterministic=deterministic,
        return_attn_probs=True,
    )
    if dropout_p > 0.0:
@@ -712,6 +715,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
@@ -725,7 +730,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
-def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
+def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
@@ -760,6 +765,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
+       deterministic=deterministic,
        return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
@@ -859,6 +865,8 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
@@ -890,7 +898,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
def test_flash_attn_output(
-   seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
+   seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
@@ -900,7 +908,7 @@ def test_flash_attn_output(
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
-   batch_size = 8
+   batch_size = 4
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
@@ -931,6 +939,7 @@ def test_flash_attn_output(
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
+           deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
@@ -942,6 +951,7 @@ def test_flash_attn_output(
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
+           deterministic=deterministic,
            return_attn_probs=True,
        )
    if dropout_p > 0.0:
@@ -1114,6 +1124,8 @@ def test_flash_attn_output(
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
@@ -1143,7 +1155,7 @@ def test_flash_attn_output(
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
-   seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
+   seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
@@ -1153,7 +1165,7 @@ def test_flash_attn_varlen_output(
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
-   batch_size = 8
+   batch_size = 4
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
@@ -1207,6 +1219,7 @@ def test_flash_attn_varlen_output(
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
+           deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
@@ -1237,6 +1250,7 @@ def test_flash_attn_varlen_output(
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
+           deterministic=deterministic,
            return_attn_probs=True,
        )
    out = output_pad_fn(out_unpad)
@@ -1675,6 +1689,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("deterministic", [False, True])
+# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True]) @pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True]) # @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False, True])
...@@ -1704,7 +1720,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp ...@@ -1704,7 +1720,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
], ],
) )
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, dtype): def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype):
if swap_sq_sk: if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda" device = "cuda"
...@@ -1729,6 +1745,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al ...@@ -1729,6 +1745,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True, return_attn_probs=True,
) )
out_ref, attn_ref = attention_ref( out_ref, attn_ref = attention_ref(
...@@ -2224,3 +2241,152 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): ...@@ -2224,3 +2241,152 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
assert not q.grad.isnan().any() assert not q.grad.isnan().any()
assert not k.grad.isnan().any() assert not k.grad.isnan().any()
assert not v.grad.isnan().any() assert not v.grad.isnan().any()
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
    g = torch.randn_like(out)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
        for _ in range(50):
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert torch.equal(dq, dq0)
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        deterministic=True,
    )
    g = torch.randn_like(out)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        # Compare every repeated backward pass against the gradients from the first one;
        # with deterministic=True they must match bitwise.
        dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
        for _ in range(50):
            dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert torch.equal(dq, dq0)
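For context (not part of the diff itself), here is a minimal standalone sketch of the property these new tests exercise: with `deterministic=True`, repeated backward passes through `flash_attn_func` on the same inputs and the same grad_output yield bitwise-identical gradients. The shapes, dtype, and flag values below are illustrative assumptions, not values taken from the commit.

```python
import torch
from flash_attn import flash_attn_func

torch.random.manual_seed(0)
device, dtype = "cuda", torch.float16
# (batch_size, seqlen, nheads, headdim) layout, as documented for flash_attn_func.
q = torch.randn(2, 512, 8, 64, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(2, 512, 8, 64, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(2, 512, 8, 64, device=device, dtype=dtype, requires_grad=True)

# deterministic=True selects the slightly slower deterministic backward kernel;
# the forward pass is deterministic either way.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True, deterministic=True)
g = torch.randn_like(out)

# Two backward passes with the same grad_output should produce identical gradients.
dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
assert torch.equal(dq, dq0) and torch.equal(dk, dk0) and torch.equal(dv, dv0)
```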