Unverified commit 8f873cc6 authored by Nicolas Patry, committed by GitHub

Implement softcapping. (#1025)

* Softcap v2 (fwd only).

* Some missing interface + remove overrides in tests.
parent 4e8d6006
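Softcapping bounds the pre-softmax attention scores: with s = q·kᵀ · softmax_scale, the kernel computes softcap · tanh(s / softcap) instead of s, so every score stays within (-softcap, softcap) while small scores are nearly unchanged. Per the commit message, only the forward pass is implemented here.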
......@@ -43,6 +43,7 @@ void set_params_fprop(Flash_fwd_params &params,
float softmax_scale,
int window_size_left,
int window_size_right,
const float softcap,
bool seqlenq_ngroups_swapped=false,
const bool unpadded_lse=false) {
......@@ -100,8 +101,19 @@ void set_params_fprop(Flash_fwd_params &params,
params.d_rounded = d_rounded;
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
#endif
if (softcap > 0.0) {
params.softcap = softmax_scale / softcap;
params.scale_softmax = softcap;
params.scale_softmax_log2 = softcap * M_LOG2E;
} else {
// Remove potential NaN
params.softcap = 0.0;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
}
// Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout;
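The branch above folds the softmax scale into the tanh argument: apply_softcap will multiply the raw q·kᵀ accumulator by params.softcap = softmax_scale / softcap, and the softmax then runs with scale softcap, which is equivalent to capping the already-scaled scores. A minimal Python sketch of that equivalence (hypothetical helper names, not part of this diff):

import math

def folded_score(qk, softmax_scale, softcap):
    # What the kernel path computes: tanh with params.softcap = softmax_scale / softcap,
    # then the softmax runs with scale_softmax = softcap.
    return softcap * math.tanh(qk * (softmax_scale / softcap))

def naive_score(qk, softmax_scale, softcap):
    # Straightforward definition: scale the score first, then cap it.
    return softcap * math.tanh((qk * softmax_scale) / softcap)

assert abs(folded_score(3.7, 0.125, 30.0) - naive_score(3.7, 0.125, 30.0)) < 1e-12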
......@@ -172,6 +184,7 @@ void set_params_dgrad(Flash_bwd_params &params,
float softmax_scale,
int window_size_left,
int window_size_right,
const float softcap,
bool deterministic,
const bool unpadded_lse) {
......@@ -187,6 +200,7 @@ void set_params_dgrad(Flash_bwd_params &params,
softmax_scale,
window_size_left,
window_size_right,
softcap,
false, // seqlenq_ngroups_swapped
unpadded_lse);
......@@ -332,6 +346,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
......@@ -453,7 +468,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
softcap
);
set_params_splitkv(params, batch_size, num_heads,
......@@ -521,6 +538,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
......@@ -688,6 +706,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
softmax_scale,
window_size_left,
window_size_right,
softcap,
seqlenq_ngroups_swapped,
/*unpadded_lse*/true);
params.total_q = total_q;
......@@ -776,6 +795,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {
......@@ -940,6 +960,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
softmax_scale,
window_size_left,
window_size_right,
softcap,
deterministic,
/*unpadded_lse*/false);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
......@@ -1009,6 +1030,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {
......@@ -1191,6 +1213,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
softmax_scale,
window_size_left,
window_size_right,
softcap,
deterministic,
/*unpadded_lse*/true);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
......@@ -1257,6 +1280,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits
) {
......@@ -1392,7 +1416,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
/*p_dropout=*/0.f,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
softcap
);
at::Tensor k, v, k_padded, v_padded;
if (k_.has_value()) {
......
......@@ -118,6 +118,7 @@ struct Flash_fwd_params : public Qkv_params {
// Local window size
int window_size_left, window_size_right;
float softcap;
// Random state.
at::PhiloxCudaState philox_args;
......
......@@ -22,6 +22,22 @@ namespace flash {
using namespace cute;
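// apply_softcap squashes every accumulator entry through tanh. Note that the `softcap`
// argument received here is params.softcap = softmax_scale / softcap (set in set_params_fprop);
// the cap value itself is re-applied afterwards through scale_softmax.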
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
static_assert(Layout::rank == 3, "Only support 3D Tensor");
static_assert(decltype(size<0>(tensor))::value == 4, "First dimension must be 4");
#pragma unroll
for (int i=0; i < size<0>(tensor); ++i){ // MMA
#pragma unroll
for (int mi=0; mi < size<1>(tensor); ++mi){
#pragma unroll
for (int nj=0; nj < size<2>(tensor); ++nj){
tensor(i, mi, nj) = cutlass::fast_tanh(tensor(i, mi, nj) * softcap);
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
......@@ -45,7 +61,7 @@ __forceinline__ __device__ auto get_lse_tile(const Params &params, const int bid
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
......@@ -318,6 +334,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
}
mask.template apply_mask<Is_causal, Is_even_MN>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
......@@ -381,6 +400,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();
__syncthreads();
......@@ -486,7 +508,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
using Element = typename Kernel_traits::Element;
......@@ -870,6 +892,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
}
mask.template apply_mask<Is_causal, Is_even_MN>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
......@@ -941,6 +967,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();
__syncthreads();
......@@ -1054,7 +1083,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
......@@ -1070,12 +1099,12 @@ inline __device__ void compute_attn(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_splitkv(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
......@@ -1084,7 +1113,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
const int n_split_idx = Split ? blockIdx.y : 0;
const int num_n_splits = Split ? gridDim.y : 1;
flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -26,18 +26,18 @@
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
......@@ -67,25 +67,27 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
// Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
......@@ -109,18 +111,20 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
......
......@@ -56,6 +56,16 @@
#define EVENK_SWITCH BOOL_SWITCH
#endif
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
#define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define SOFTCAP_SWITCH BOOL_SWITCH
#endif
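SOFTCAP_SWITCH mirrors the other switches in this header: at launch time it maps the runtime condition params.softcap > 0.0 onto the compile-time Is_softcap template flag, and when the build defines FLASHATTENTION_DISABLE_SOFTCAP it collapses to a constant false so no softcap kernels are instantiated.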
#ifdef FLASHATTENTION_DISABLE_LOCAL
#define LOCAL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
......
......@@ -44,7 +44,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
def _flash_attn_forward(
q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
......@@ -59,6 +59,7 @@ def _flash_attn_forward(
causal,
window_size[0],
window_size[1],
softcap,
return_softmax,
None,
)
......@@ -123,6 +124,7 @@ def _flash_attn_backward(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
rng_state=None,
......@@ -151,6 +153,7 @@ def _flash_attn_backward(
causal,
window_size[0],
window_size[1],
softcap,
deterministic,
None,
rng_state,
......@@ -176,6 +179,7 @@ def _flash_attn_varlen_backward(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
rng_state=None,
......@@ -209,6 +213,7 @@ def _flash_attn_varlen_backward(
causal,
window_size[0],
window_size[1],
softcap,
deterministic,
None,
rng_state,
......@@ -227,6 +232,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -241,6 +247,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
)
......@@ -249,6 +256,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -272,6 +280,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
......@@ -433,6 +442,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -451,6 +461,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=None,
......@@ -464,6 +475,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -492,6 +504,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
......@@ -512,6 +525,7 @@ class FlashAttnFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -526,6 +540,7 @@ class FlashAttnFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
)
......@@ -534,6 +549,7 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -556,6 +572,7 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
......@@ -581,6 +598,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_softmax,
......@@ -600,6 +618,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=block_table,
......@@ -613,6 +632,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
......@@ -639,6 +659,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
......@@ -655,6 +676,7 @@ def flash_attn_qkvpacked_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # <=0.0 means deactivate
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -676,6 +698,7 @@ def flash_attn_qkvpacked_func(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. If > 0.0, attention scores are capped to softcap * tanh(score / softcap) before the softmax; 0.0 disables softcapping.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
......@@ -698,6 +721,7 @@ def flash_attn_qkvpacked_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
......@@ -711,6 +735,7 @@ def flash_attn_kvpacked_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -748,6 +773,7 @@ def flash_attn_kvpacked_func(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. If > 0.0, attention scores are capped to softcap * tanh(score / softcap) before the softmax; 0.0 disables softcapping.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
......@@ -772,6 +798,7 @@ def flash_attn_kvpacked_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
......@@ -786,6 +813,7 @@ def flash_attn_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -846,6 +874,7 @@ def flash_attn_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
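A minimal usage sketch of the new keyword on flash_attn_func (illustrative shapes and softcap value; assumes a CUDA device, fp16 inputs, and a flash_attn build that includes this change — forward pass only, per the commit message):

import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim), fp16 on GPU as flash-attn requires.
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

# softcap=30.0 is an illustrative value; the default 0.0 leaves capping disabled.
out = flash_attn_func(q, k, v, causal=True, softcap=30.0)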
......@@ -860,6 +889,7 @@ def flash_attn_varlen_qkvpacked_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -884,6 +914,7 @@ def flash_attn_varlen_qkvpacked_func(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. If > 0.0, attention scores are capped to softcap * tanh(score / softcap) before the softmax; 0.0 disables softcapping.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
......@@ -908,6 +939,7 @@ def flash_attn_varlen_qkvpacked_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
......@@ -925,6 +957,7 @@ def flash_attn_varlen_kvpacked_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -968,6 +1001,7 @@ def flash_attn_varlen_kvpacked_func(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. If > 0.0, attention scores are capped to softcap * tanh(score / softcap) before the softmax; 0.0 disables softcapping.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
......@@ -996,6 +1030,7 @@ def flash_attn_varlen_kvpacked_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
......@@ -1014,6 +1049,7 @@ def flash_attn_varlen_func(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
alibi_slopes=None,
deterministic=False,
return_attn_probs=False,
......@@ -1056,6 +1092,7 @@ def flash_attn_varlen_func(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. If > 0.0, attention scores are capped to softcap * tanh(score / softcap) before the softmax; 0.0 disables softcapping.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
......@@ -1085,6 +1122,7 @@ def flash_attn_varlen_func(
softmax_scale,
causal,
window_size,
softcap,
alibi_slopes,
deterministic,
return_attn_probs,
......@@ -1106,6 +1144,7 @@ def flash_attn_with_kvcache(
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
alibi_slopes=None,
num_splits=0,
......@@ -1177,6 +1216,7 @@ def flash_attn_with_kvcache(
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. If > 0.0, attention scores are capped to softcap * tanh(score / softcap) before the softmax; 0.0 disables softcapping.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
......@@ -1226,6 +1266,7 @@ def flash_attn_with_kvcache(
causal,
window_size[0],
window_size[1],
softcap,
rotary_interleaved,
num_splits,
)
......
......@@ -203,6 +203,7 @@ if not SKIP_CUDA_BUILD:
# "-DFLASHATTENTION_DISABLE_BACKWARD",
# "-DFLASHATTENTION_DISABLE_DROPOUT",
# "-DFLASHATTENTION_DISABLE_ALIBI",
# "-DFLASHATTENTION_DISABLE_SOFTCAP",
# "-DFLASHATTENTION_DISABLE_UNEVEN_K",
# "-DFLASHATTENTION_DISABLE_LOCAL",
]
......
......@@ -216,6 +216,7 @@ def attention_ref(
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
):
......@@ -253,6 +254,10 @@ def attention_ref(
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if softcap > 0:
scores /= softcap
scores = scores.tanh()
scores *= softcap
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
......@@ -877,8 +882,9 @@ def test_flash_attn_varlen_qkvpacked(
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
def test_flash_attn_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
if (
max(seqlen_q, seqlen_k) >= 2048
......@@ -894,6 +900,9 @@ def test_flash_attn_output(
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if softcap > 0:
# Scale q so that the qk scores actually reach the softcap range.
q = q * softcap
if kvpacked:
kv = torch.randn(
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
......@@ -918,6 +927,7 @@ def test_flash_attn_output(
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
......@@ -930,6 +940,7 @@ def test_flash_attn_output(
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
......@@ -984,6 +995,7 @@ def test_flash_attn_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_kvpacked_ref(
q,
......@@ -995,6 +1007,7 @@ def test_flash_attn_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
......@@ -1010,6 +1023,7 @@ def test_flash_attn_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_ref(
q,
......@@ -1022,6 +1036,7 @@ def test_flash_attn_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
......@@ -1133,9 +1148,10 @@ def test_flash_attn_output(
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
if (
max(seqlen_q, seqlen_k) >= 2048
......@@ -1151,6 +1167,9 @@ def test_flash_attn_varlen_output(
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
if softcap > 0:
# Scale q so that the qk scores actually reach the softcap range.
q = q * softcap
if kvpacked:
kv = torch.randn(
......@@ -1199,6 +1218,7 @@ def test_flash_attn_varlen_output(
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
......@@ -1230,6 +1250,7 @@ def test_flash_attn_varlen_output(
dropout_p,
causal=causal,
window_size=window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
return_attn_probs=True,
......@@ -1289,6 +1310,7 @@ def test_flash_attn_varlen_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_kvpacked_ref(
q,
......@@ -1300,6 +1322,7 @@ def test_flash_attn_varlen_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
......@@ -1315,6 +1338,7 @@ def test_flash_attn_varlen_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
)
out_pt, attn_pt = attention_ref(
q,
......@@ -1327,6 +1351,7 @@ def test_flash_attn_varlen_output(
dropout_mask,
causal=causal,
window_size=window_size,
softcap=softcap,
upcast=False,
reorder_ops=True,
)
......