Commit 73265458 authored by Tri Dao

Implement deterministic backward (thanks to Meituan)

parent 2c7d7b73
......@@ -83,7 +83,7 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None):
window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
......@@ -99,6 +99,8 @@ Arguments:
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
......@@ -106,7 +108,7 @@ Return:
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None):
window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
......@@ -128,6 +130,8 @@ Arguments:
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
......@@ -269,10 +273,12 @@ Implement sliding window attention (i.e., local attention). Thanks to [Mistral
AI](https://mistral.ai/) and in particular Timothée Lacroix for this
contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
### 2.4: ALiBi (attention with linear bias)
### 2.4: ALiBi (attention with linear bias), deterministic backward pass
Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
Implement deterministic backward pass. Thanks to engineers from [Meituan](https://www.meituan.com) for this contribution.
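A minimal sketch of how the new flag can be exercised, mirroring the determinism tests added in this commit (shapes below are illustrative): run the backward twice with `deterministic=True` and the gradients compare bitwise-equal.
```python
import torch
from flash_attn import flash_attn_func

q, k, v = [
    torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
]
out = flash_attn_func(q, k, v, causal=True, deterministic=True)
g = torch.randn_like(out)
dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
dq1, dk1, dv1 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
assert torch.equal(dq0, dq1) and torch.equal(dk0, dk1) and torch.equal(dv0, dv1)
```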
## Performance
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
......
......@@ -150,7 +150,8 @@ void set_params_dgrad(Flash_bwd_params &params,
float p_dropout,
float softmax_scale,
int window_size_left,
int window_size_right) {
int window_size_right,
bool deterministic) {
set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
......@@ -192,6 +193,8 @@ void set_params_dgrad(Flash_bwd_params &params,
// Softmax sum
params.dsoftmax_sum = dsoftmax_sum_d;
params.deterministic = deterministic;
}
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
......@@ -618,8 +621,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
params.alibi_slopes_ptr = nullptr;
}
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
} else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}
at::Tensor out_padded = out;
if (head_size_og % 8 != 0) {
......@@ -668,6 +677,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const bool is_causal,
const int window_size_left,
int window_size_right,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {
......@@ -783,7 +793,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
at::Tensor dq_accum;
at::Tensor dk_accum, dv_accum;
if (loop) {
dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
if (!deterministic) {
dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
} else {
const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
}
// dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
// dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
}
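A hedged sketch of the allocation logic above: in the deterministic path `dq_accum` gains a leading `nsplits` dimension so that each group of thread blocks owns its own fp32 accumulator (summed later in `convert_dQ`), trading extra workspace for reproducibility. The SM count below is an assumption for illustration only.
```python
def dq_accum_nsplits(sm_count: int, batch_size: int, num_heads: int) -> int:
    # Same ceil-division as in mha_bwd above:
    # (multiProcessorCount + batch*heads - 1) / (batch*heads)
    return (sm_count + batch_size * num_heads - 1) // (batch_size * num_heads)

# e.g. a hypothetical 108-SM GPU with batch_size=4, num_heads=9:
nsplits = dq_accum_nsplits(108, 4, 9)  # -> 3
# non-deterministic: dq_accum shape (batch, seqlen_q_rounded, heads, head_size_rounded), fp32
# deterministic:     dq_accum shape (nsplits, batch, seqlen_q_rounded, heads, head_size_rounded), fp32
```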
......@@ -819,7 +834,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
deterministic);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
auto launch = &run_mha_bwd;
// launch(params, stream, /*configure=*/true);
......@@ -857,8 +874,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
launch(params, stream, /*configure=*/false);
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk.zero_();
dv.zero_();
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}
......@@ -897,6 +914,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const bool is_causal,
const int window_size_left,
int window_size_right,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {
......@@ -1025,7 +1043,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
// cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
// be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
// allowed to do. So we won't have to do any bound checking, and performance should stay the same.
dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
if (!deterministic) {
dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
} else {
const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
}
}
at::Tensor dk_expanded, dv_expanded;
......@@ -1064,7 +1087,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
deterministic);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
auto launch = &run_mha_bwd;
// launch(params, stream, /*configure=*/true);
......@@ -1098,7 +1123,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
params.alibi_slopes_ptr = nullptr;
}
launch(params, stream, /*configure=*/false);
if (max_seqlen_q > 0) {
launch(params, stream, /*configure=*/false);
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}
// For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) {
......
......@@ -172,6 +172,9 @@ struct Flash_bwd_params : public Flash_fwd_params {
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
bool deterministic;
index_t dq_accum_split_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -230,7 +230,7 @@ inline __device__ void clear_dKVaccum(const Params &params) {
// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dQ(const Params &params) {
inline __device__ void convert_dQ(const Params &params, const int nsplits) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
......@@ -285,11 +285,15 @@ inline __device__ void convert_dQ(const Params &params) {
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) {
acc_dq(i) = tdQrdQaccum(i) * params.scale_softmax_rp_dropout;
clear(acc_dq);
for (int s = 0; s < nsplits; ++s) {
cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
}
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
// Convert acc_dq from fp32 to fp16
Tensor rdQ = flash::convert_type<Element>(acc_dq);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
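At a high level, the modified `convert_dQ` above now sums the per-split accumulators in a fixed order before rescaling and down-converting; a rough PyTorch-level equivalent (tensor names are illustrative, not the kernel's):
```python
import torch

def convert_dq_reference(dq_accum: torch.Tensor, scale: float, dtype=torch.float16) -> torch.Tensor:
    # dq_accum: (nsplits, batch, seqlen_q_rounded, nheads, head_size_rounded), fp32
    dq = dq_accum.sum(dim=0) * scale  # fixed-order reduction over splits, then scale_softmax_rp_dropout
    return dq.to(dtype)               # convert fp32 -> fp16/bf16, as in the kernel epilogue
```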
......@@ -466,7 +470,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+ ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
+ ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
+ (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
+ (m_block_max - 1) * kBlockM;
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
......@@ -715,7 +721,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
tdKsQt.data() = tdKsQt.data() + size(sQ);
}
if (!Is_first && !Seq_parallel) { __syncthreads(); }
if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); }
if (Kernel_traits::Is_V_in_regs) {
// Clear the smem tiles to account for predicated off loads
......@@ -1604,13 +1610,15 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
const int n_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}
}
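In the deterministic path the grid is capped (see the launch change further down), so each thread block walks several key blocks with the grid-stride loop above instead of owning exactly one; the mapping can be sketched as follows (pure Python, illustrative):
```python
def key_blocks_for(block_idx_x: int, grid_dim_x: int, seqlen_k: int, kBlockN: int):
    # Which n_block values a given blockIdx.x processes under the grid-stride loop.
    num_n_block = (seqlen_k + kBlockN - 1) // kBlockN
    return list(range(block_idx_x, num_n_block, grid_dim_x))

# e.g. seqlen_k=1024, kBlockN=128 -> 8 key blocks; with gridDim.x capped at 3:
# block 0 -> [0, 3, 6], block 1 -> [1, 4, 7], block 2 -> [2, 5]
```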
////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -35,8 +35,8 @@ __global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params pa
}
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) {
flash::convert_dQ<Kernel_traits>(params);
__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params, const int nsplits) {
flash::convert_dQ<Kernel_traits>(params, nsplits);
}
template<typename Kernel_traits>
......@@ -49,9 +49,18 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid_m(num_m_block, params.b, params.h);
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
dim3 grid_n(num_n_block, params.b, params.h);
int gridDimx = num_n_block;
if (params.deterministic) {
auto dprops = at::cuda::getCurrentDeviceProperties();
gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
}
dim3 grid_n(gridDimx, params.b, params.h);
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
if (!params.deterministic) {
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
} else {
flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
// We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
......@@ -69,6 +78,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
......@@ -86,7 +96,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
}
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
......
......@@ -52,6 +52,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
......
......@@ -122,6 +122,7 @@ def _flash_attn_backward(
causal,
window_size,
alibi_slopes,
deterministic,
rng_state=None,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
......@@ -143,6 +144,7 @@ def _flash_attn_backward(
causal,
window_size[0],
window_size[1],
deterministic,
None,
rng_state,
)
......@@ -168,6 +170,7 @@ def _flash_attn_varlen_backward(
causal,
window_size,
alibi_slopes,
deterministic,
rng_state=None,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
......@@ -194,6 +197,7 @@ def _flash_attn_varlen_backward(
causal,
window_size[0],
window_size[1],
deterministic,
None,
rng_state,
)
......@@ -205,7 +209,15 @@ def _flash_attn_varlen_backward(
class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
ctx,
qkv,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
......@@ -226,6 +238,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
......@@ -248,10 +261,11 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None
return dqkv, None, None, None, None, None, None, None
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
......@@ -266,6 +280,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
......@@ -292,6 +307,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
......@@ -318,16 +334,26 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None, None, None
return dqkv, None, None, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
ctx,
q,
kv,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -348,6 +374,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
......@@ -371,11 +398,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None
return dq, dkv, None, None, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
......@@ -393,6 +421,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
......@@ -422,6 +451,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
......@@ -449,17 +479,28 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None, None, None, None
return dq, dkv, None, None, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
ctx,
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
......@@ -480,6 +521,7 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
......@@ -501,12 +543,13 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function):
......@@ -525,6 +568,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
causal,
window_size,
alibi_slopes,
deterministic,
return_softmax,
):
if softmax_scale is None:
......@@ -554,6 +598,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.causal = causal
ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod
......@@ -579,12 +624,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.causal,
ctx.window_size,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func(
......@@ -594,6 +640,7 @@ def flash_attn_qkvpacked_func(
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
......@@ -615,6 +662,8 @@ def flash_attn_qkvpacked_func(
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
......@@ -628,7 +677,14 @@ def flash_attn_qkvpacked_func(
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnQKVPackedFunc.apply(
qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
qkv,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
......@@ -640,6 +696,7 @@ def flash_attn_kvpacked_func(
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
......@@ -678,6 +735,8 @@ def flash_attn_kvpacked_func(
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
......@@ -691,7 +750,15 @@ def flash_attn_kvpacked_func(
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnKVPackedFunc.apply(
q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
q,
kv,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
......@@ -704,6 +771,7 @@ def flash_attn_func(
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
......@@ -740,6 +808,8 @@ def flash_attn_func(
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
......@@ -753,7 +823,16 @@ def flash_attn_func(
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnFunc.apply(
q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
q,
k,
v,
dropout_p,
softmax_scale,
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
......@@ -766,6 +845,7 @@ def flash_attn_varlen_qkvpacked_func(
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
......@@ -790,6 +870,8 @@ def flash_attn_varlen_qkvpacked_func(
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
......@@ -811,6 +893,7 @@ def flash_attn_varlen_qkvpacked_func(
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
......@@ -827,6 +910,7 @@ def flash_attn_varlen_kvpacked_func(
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
......@@ -871,6 +955,8 @@ def flash_attn_varlen_kvpacked_func(
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
......@@ -895,6 +981,7 @@ def flash_attn_varlen_kvpacked_func(
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
......@@ -912,6 +999,7 @@ def flash_attn_varlen_func(
causal=False,
window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
......@@ -954,6 +1042,8 @@ def flash_attn_varlen_func(
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
......@@ -979,6 +1069,7 @@ def flash_attn_varlen_func(
causal,
window_size,
alibi_slopes,
deterministic,
return_attn_probs,
)
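For reference, a hedged usage sketch of the varlen API with the new flag (the packing of two sequences of lengths 3 and 5 below is purely illustrative):
```python
import torch
from flash_attn import flash_attn_varlen_func

nheads, headdim = 8, 64
cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)  # two sequences: lengths 3 and 5
total = 8
q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True)
out = flash_attn_varlen_func(
    q, k, v, cu_seqlens, cu_seqlens, 5, 5,
    dropout_p=0.0, causal=True, deterministic=True,
)
out.sum().backward()
```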
......
......@@ -566,10 +566,12 @@ def get_dropout_fraction(
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
......@@ -578,16 +580,16 @@ def get_dropout_fraction(
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [97])
# @pytest.mark.parametrize("seqlen", [512])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 8
batch_size = 4
nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
qkv = torch.randn(
......@@ -604,6 +606,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
if dropout_p > 0.0:
......@@ -712,6 +715,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
......@@ -725,7 +730,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype)
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = "cuda"
......@@ -760,6 +765,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
out = output_pad_fn(out_unpad)
......@@ -859,6 +865,8 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
......@@ -890,7 +898,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
def test_flash_attn_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
if (
max(seqlen_q, seqlen_k) >= 2048
......@@ -900,7 +908,7 @@ def test_flash_attn_output(
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 8
batch_size = 4
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
......@@ -931,6 +939,7 @@ def test_flash_attn_output(
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
else:
......@@ -942,6 +951,7 @@ def test_flash_attn_output(
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
if dropout_p > 0.0:
......@@ -1114,6 +1124,8 @@ def test_flash_attn_output(
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
......@@ -1143,7 +1155,7 @@ def test_flash_attn_output(
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
if (
max(seqlen_q, seqlen_k) >= 2048
......@@ -1153,7 +1165,7 @@ def test_flash_attn_varlen_output(
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 8
batch_size = 4
nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0
......@@ -1207,6 +1219,7 @@ def test_flash_attn_varlen_output(
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
else:
......@@ -1237,6 +1250,7 @@ def test_flash_attn_varlen_output(
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
out = output_pad_fn(out_unpad)
......@@ -1675,6 +1689,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
......@@ -1704,7 +1720,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, dtype):
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype):
if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda"
......@@ -1729,6 +1745,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, al
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
)
out_ref, attn_ref = attention_ref(
......@@ -2224,3 +2241,152 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
assert not q.grad.isnan().any()
assert not k.grad.isnan().any()
assert not v.grad.isnan().any()
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 239),
(3, 799),
(127, 512),
(127, 513),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(1023, 1024),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 4
nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)
g = torch.randn_like(out)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
for _ in range(50):
dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
assert torch.equal(dv, dv0)
assert torch.equal(dk, dk0)
assert torch.equal(dq, dq0)
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 239),
(3, 799),
(127, 512),
(127, 513),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(1023, 1024),
],
)
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
):
pytest.skip() # Reference implementation OOM
if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 2
nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
out = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
causal=causal,
window_size=window_size,
deterministic=True,
)
g = torch.randn_like(out)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
for _ in range(50):
dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
assert torch.equal(dv, dv0)
assert torch.equal(dk, dk0)
assert torch.equal(dq, dq0)