Unverified Commit e4f726fc authored by Sanghun Cho, committed by GitHub

Support alibi, by Sanghun Cho from Kakao Brain



* hard-code alibi in fwd

* use params.h as num_heads

* hard-code alibi in bwd

* add alibi on/off option

* compute alibi_start, ratio outside of kernels

* fix minor merge conflict

* add test_alibi.py

* change apply_alibi() location before masking

* add alibi in splitkv kernel

* fix backward func # of returns

* add out-of-bound check in apply_alibi()

* update test_alibi.py

* update test_alibi.py for kvcache

* simplify alibi parameter interface

* fix performance issue
by computing alibi outside of branch

* update test_flash_attn_varlen_func() for left padding

* implement alibi_slopes (b, nh) loading

* optimize apply_alibi() a bit

* update test cases for alibi_slopes loading

* reflect stylistic comments

* disable "seqlenq_ngroups_swapped" when using alibi

---------
Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
parent cd089597
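
For context, here is a minimal Python sketch of how the new alibi_slopes argument is intended to be used. It assumes the public flash_attn_func wrapper exposes the same alibi_slopes keyword that this patch threads through _flash_attn_forward; the get_alibi_slopes helper is a hypothetical re-implementation of the standard ALiBi geometric-sequence slopes and is not part of this commit.

import torch
from flash_attn import flash_attn_func  # assumes the packaged interface

def get_alibi_slopes(num_heads):
    # Hypothetical helper: standard ALiBi slopes for a power-of-two head count,
    # i.e. the geometric sequence 2^(-8/num_heads), 2^(-16/num_heads), ...
    start = 2 ** (-8.0 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)], dtype=torch.float32)

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# The kernels in this patch expect fp32 slopes of shape (batch_size, num_heads);
# see the CHECK_SHAPE(alibi_slopes, batch_size, num_heads) calls below.
alibi_slopes = get_alibi_slopes(nheads).cuda().repeat(batch, 1)

out = flash_attn_func(q, k, v, causal=True, alibi_slopes=alibi_slopes)
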
@@ -258,6 +258,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
bool is_causal,
const int window_size_left,
int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
const bool return_softmax,
c10::optional<at::Generator> gen_) {
@@ -301,7 +302,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
// TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -409,6 +411,19 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
params.philox_args = gen->philox_cuda_state(counter_offset);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
} else {
params.has_alibi = false;
}
if (seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
@@ -449,6 +464,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const bool is_causal,
const int window_size_left,
int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
const bool return_softmax,
c10::optional<at::Generator> gen_) {
@@ -591,6 +607,19 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
params.philox_args = gen->philox_cuda_state(counter_offset);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
} else {
params.has_alibi = false;
}
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
@@ -640,6 +669,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const bool is_causal,
const int window_size_left,
int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {
@@ -813,6 +843,19 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
params.rng_state[1] = std::get<1>(seeds);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
} else {
params.has_alibi = false;
}
if (seqlen_q > 0) {
launch(params, stream, /*configure=*/false);
} else {
@@ -856,6 +899,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const bool is_causal,
const int window_size_left,
int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {
@@ -1045,6 +1089,19 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
params.rng_state[1] = std::get<1>(seeds);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
} else {
params.has_alibi = false;
}
launch(params, stream, /*configure=*/false);
// For MQA/GQA we need to sum dK and dV across the groups
@@ -1077,7 +1134,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
const int window_size_left,
int window_size_right,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits,
c10::optional<at::Tensor> &alibi_slopes_ // batch_size x num_heads
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -1121,7 +1179,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
// TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -1283,6 +1342,19 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
params.oaccum_ptr = out_accum.data_ptr();
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
} else {
params.has_alibi = false;
}
auto stream = at::cuda::getCurrentCUDAStream().stream();
// Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx
run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value());
......
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout>
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int max_seqlen_k,
const int row_idx_offset_,
const int max_seqlen_q,
const int warp_row_stride,
const int head_idx,
const float softmax_scale,
const float alibi_slope) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int row_idx_offset = row_idx_offset_;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
const float alibi_slope_unscaled = alibi_slope / softmax_scale;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
const float alibi = alibi_slope_unscaled * col_idx;
if (col_idx < max_seqlen_k && row_idx < max_seqlen_q) {
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi;
}
}
}
}
}
}
} // namespace flash
\ No newline at end of file
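
For intuition, here is a small NumPy sketch (an illustration, not code from this commit) of the bias that apply_alibi() effectively applies: the kernel adds (alibi_slope / softmax_scale) * col_idx to the unscaled Q*K^T scores, which after the later multiplication by softmax_scale amounts to adding alibi_slope * col_idx; since softmax is invariant to a constant shift per row, this is equivalent to the usual relative-position bias alibi_slope * (col_idx - row_idx).

import numpy as np

def alibi_bias_reference(slope, seqlen_q, seqlen_k):
    # Per-column bias as the kernel applies it (the per-row term slope * row_idx
    # is omitted; softmax cancels it anyway).
    col_idx = np.arange(seqlen_k, dtype=np.float32)
    return np.broadcast_to(slope * col_idx, (seqlen_q, seqlen_k))

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

# Sanity check: the per-column form and the relative form give the same attention weights.
rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 6)).astype(np.float32)  # already-scaled Q*K^T for one head
slope = 0.0625
rel_bias = slope * (np.arange(6)[None, :] - np.arange(4)[:, None])
assert np.allclose(softmax(scores + alibi_bias_reference(slope, 4, 6)),
                   softmax(scores + rel_bias))
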
@@ -130,6 +130,13 @@ struct Flash_fwd_params : public Qkv_params {
bool is_rotary_interleaved;
int num_splits; // For split-KV version
// float alibi_start;
// float alibi_ratio;
bool has_alibi;
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
......
@@ -15,6 +15,8 @@
#include "utils.h"
#include "softmax.h"
#include "alibi.h"
namespace flash {
using namespace cute;
@@ -422,7 +424,7 @@ inline __device__ void convert_dKV(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
using Element = typename Kernel_traits::Element;
@@ -790,6 +792,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
clear(acc_dv);
clear(acc_dk);
float alibi_slope = 0.0f;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
}
for (; m_block >= m_block_min; --m_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
clear(acc_s);
@@ -813,6 +828,20 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// if (cute::thread(32, 0)) { print(scores); }
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
AtomLayoutMS * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
// actual_seqlen_k, because acc_s would be some finite value for those indices.
// In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
@@ -849,6 +878,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
}
}
// if (cute::thread(32, 0)) { print(scores); }
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
@@ -1114,7 +1144,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
@@ -1373,6 +1403,19 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
clear(acc_dq);
float alibi_slope = 0.0f;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
}
for (; n_block >= 0; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M_SdP, MMA_N)
clear(acc_s);
@@ -1384,6 +1427,20 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
AtomLayoutMS * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
// We don't need to mask out the elements beyond actual_seqlen_k, because acc_s would
// be some finite value for those indices. In the end when we multiply with K to get dQ,
// the corresponding values of K would be 0, so the result would still be correct.
@@ -1394,6 +1451,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
binfo.actual_seqlen_q,
AtomLayoutMS * 16);
}
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
if (Is_dropout) {
@@ -1536,7 +1594,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv(const Params &params) {
// The block index for the batch.
@@ -1550,20 +1608,20 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
if (n_block_max == 1) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
} else {
// Iterating backward from n_block_max - 1 to 0 might save 1 register
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
}
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
const int n_block = blockIdx.x;
@@ -1572,12 +1630,12 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
// The block index for the head.
const int bidh = blockIdx.z;
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_seqq_parallel(const Params &params) {
const int m_block = blockIdx.x;
@@ -1586,7 +1644,7 @@ inline __device__ void compute_dq_dk_dv_seqq_parallel(const Params &params) {
// The block index for the head.
const int bidh = blockIdx.z;
compute_dq_dk_dv_1rowblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_N, Is_even_K>(params, bidb, bidh, m_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......
@@ -18,20 +18,20 @@ __global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) {
flash::clear_dKVaccum<Kernel_traits>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) {
flash::compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_N, Is_even_K>(params);
}
template<typename Kernel_traits>
@@ -64,11 +64,12 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -79,6 +80,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
});
});
});
});
auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
@@ -107,8 +109,9 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -119,6 +122,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
});
});
});
});
auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
if (Kernel_traits::kSmemKVSize >= 48 * 1024) {
......
@@ -15,6 +15,8 @@
#include "utils.h"
#include "softmax.h"
#include "alibi.h"
namespace flash {
using namespace cute;
@@ -71,7 +73,7 @@ inline __device__ void write_softmax_to_gmem(
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
@@ -326,6 +328,22 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
float alibi_slope = 0.0f;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
// if (m_block == 0 && tidx == 0) {
// printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
// }
}
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local)
@@ -362,6 +380,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
if (!Is_causal && !Is_local) {
if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
@@ -466,6 +498,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
flash::apply_mask_local(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
@@ -474,6 +520,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
params.window_size_left, params.window_size_right
);
}
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(scores);
@@ -581,7 +628,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
using Element = typename Kernel_traits::Element;
@@ -909,6 +956,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
float alibi_slope = 0.0f;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
// if (m_block == 0 && tidx == 0) {
// printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
// }
}
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local)
@@ -941,6 +1004,20 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
// if (cute::thread0()) { print(scores); }
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
@@ -1020,6 +1097,20 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
flash::apply_mask_local(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
@@ -1131,7 +1222,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -1147,12 +1238,12 @@ inline __device__ void compute_attn(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_splitkv(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -1161,7 +1252,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
const int n_split_idx = Split ? blockIdx.y : 0;
const int num_n_splits = Split ? gridDim.y : 1;
flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......
@@ -10,15 +10,15 @@
#include "flash.h"
#include "flash_fwd_kernel.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
}
template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
@@ -45,12 +45,13 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
// Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
@@ -67,6 +68,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
});
});
});
});
}
template<typename Kernel_traits>
@@ -84,10 +86,11 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
if (smem_size >= 48 * 1024) {
@@ -102,6 +105,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
});
});
});
});
if (params.num_splits > 1) {
// We want kBlockM to be as small as possible for more parallelism.
// With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
......
@@ -43,7 +43,7 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
return (128, 64) if is_sm80 else (64, 64)
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
@@ -56,6 +56,7 @@ def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size,
causal,
window_size[0],
window_size[1],
alibi_slopes,
return_softmax,
None,
)
@@ -74,6 +75,7 @@ def _flash_attn_varlen_forward(
softmax_scale,
causal,
window_size,
alibi_slopes,
return_softmax,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
@@ -94,6 +96,7 @@ def _flash_attn_varlen_forward(
causal,
window_size[0],
window_size[1],
alibi_slopes,
return_softmax, return_softmax,
None, None,
) )
...@@ -116,6 +119,7 @@ def _flash_attn_backward( ...@@ -116,6 +119,7 @@ def _flash_attn_backward(
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
rng_state=None, rng_state=None,
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
...@@ -136,6 +140,7 @@ def _flash_attn_backward( ...@@ -136,6 +140,7 @@ def _flash_attn_backward(
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
None, None,
rng_state, rng_state,
) )
...@@ -160,6 +165,7 @@ def _flash_attn_varlen_backward( ...@@ -160,6 +165,7 @@ def _flash_attn_varlen_backward(
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
rng_state=None, rng_state=None,
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
...@@ -185,6 +191,7 @@ def _flash_attn_varlen_backward( ...@@ -185,6 +191,7 @@ def _flash_attn_varlen_backward(
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
None, None,
rng_state, rng_state,
) )
...@@ -195,7 +202,7 @@ def _flash_attn_varlen_backward( ...@@ -195,7 +202,7 @@ def _flash_attn_varlen_backward(
class FlashAttnQKVPackedFunc(torch.autograd.Function): class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, return_softmax): def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5) softmax_scale = qkv.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...@@ -206,6 +213,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): ...@@ -206,6 +213,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
...@@ -213,6 +221,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): ...@@ -213,6 +221,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -234,10 +243,11 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): ...@@ -234,10 +243,11 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None return dqkv, None, None, None, None, None, None
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
...@@ -251,6 +261,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): ...@@ -251,6 +261,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_softmax, return_softmax,
): ):
if softmax_scale is None: if softmax_scale is None:
...@@ -267,6 +278,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): ...@@ -267,6 +278,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
...@@ -275,6 +287,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): ...@@ -275,6 +287,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -300,15 +313,16 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): ...@@ -300,15 +313,16 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None, None return dqkv, None, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function): class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, return_softmax): def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...@@ -319,6 +333,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function): ...@@ -319,6 +333,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
...@@ -326,6 +341,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function): ...@@ -326,6 +341,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -348,11 +364,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function): ...@@ -348,11 +364,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]] dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None return dq, dkv, None, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
...@@ -369,6 +386,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -369,6 +386,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_softmax, return_softmax,
): ):
if softmax_scale is None: if softmax_scale is None:
...@@ -385,6 +403,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -385,6 +403,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward( ctx.save_for_backward(
...@@ -396,6 +415,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -396,6 +415,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -422,16 +442,17 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -422,16 +442,17 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]] dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None, None, None return dq, dkv, None, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function): class FlashAttnFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax): def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...@@ -442,6 +463,7 @@ class FlashAttnFunc(torch.autograd.Function): ...@@ -442,6 +463,7 @@ class FlashAttnFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
...@@ -449,6 +471,7 @@ class FlashAttnFunc(torch.autograd.Function): ...@@ -449,6 +471,7 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -469,12 +492,13 @@ class FlashAttnFunc(torch.autograd.Function): ...@@ -469,12 +492,13 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]] dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None return dq, dk, dv, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function): class FlashAttnVarlenFunc(torch.autograd.Function):
...@@ -492,6 +516,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -492,6 +516,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_softmax, return_softmax,
): ):
if softmax_scale is None: if softmax_scale is None:
...@@ -508,6 +533,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -508,6 +533,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward( ctx.save_for_backward(
...@@ -519,6 +545,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -519,6 +545,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -543,12 +570,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -543,12 +570,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]] dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func( def flash_attn_qkvpacked_func(
...@@ -557,6 +585,7 @@ def flash_attn_qkvpacked_func( ...@@ -557,6 +585,7 @@ def flash_attn_qkvpacked_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -589,7 +618,7 @@ def flash_attn_qkvpacked_func( ...@@ -589,7 +618,7 @@ def flash_attn_qkvpacked_func(
pattern (negative means that location was dropped, nonnegative means it was kept). pattern (negative means that location was dropped, nonnegative means it was kept).
""" """
return FlashAttnQKVPackedFunc.apply( return FlashAttnQKVPackedFunc.apply(
qkv, dropout_p, softmax_scale, causal, window_size, return_attn_probs qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
) )
...@@ -600,6 +629,7 @@ def flash_attn_kvpacked_func( ...@@ -600,6 +629,7 @@ def flash_attn_kvpacked_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -648,7 +678,7 @@ def flash_attn_kvpacked_func( ...@@ -648,7 +678,7 @@ def flash_attn_kvpacked_func(
pattern (negative means that location was dropped, nonnegative means it was kept). pattern (negative means that location was dropped, nonnegative means it was kept).
""" """
return FlashAttnKVPackedFunc.apply( return FlashAttnKVPackedFunc.apply(
q, kv, dropout_p, softmax_scale, causal, window_size, return_attn_probs q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
) )
...@@ -660,6 +690,7 @@ def flash_attn_func( ...@@ -660,6 +690,7 @@ def flash_attn_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -706,7 +737,7 @@ def flash_attn_func( ...@@ -706,7 +737,7 @@ def flash_attn_func(
pattern (negative means that location was dropped, nonnegative means it was kept). pattern (negative means that location was dropped, nonnegative means it was kept).
""" """
return FlashAttnFunc.apply( return FlashAttnFunc.apply(
q, k, v, dropout_p, softmax_scale, causal, window_size, return_attn_probs q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
) )
...@@ -718,6 +749,7 @@ def flash_attn_varlen_qkvpacked_func( ...@@ -718,6 +749,7 @@ def flash_attn_varlen_qkvpacked_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -760,6 +792,7 @@ def flash_attn_varlen_qkvpacked_func( ...@@ -760,6 +792,7 @@ def flash_attn_varlen_qkvpacked_func(
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_attn_probs, return_attn_probs,
) )
...@@ -775,6 +808,7 @@ def flash_attn_varlen_kvpacked_func( ...@@ -775,6 +808,7 @@ def flash_attn_varlen_kvpacked_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -839,6 +873,7 @@ def flash_attn_varlen_kvpacked_func( ...@@ -839,6 +873,7 @@ def flash_attn_varlen_kvpacked_func(
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_attn_probs, return_attn_probs,
) )
...@@ -855,6 +890,7 @@ def flash_attn_varlen_func( ...@@ -855,6 +890,7 @@ def flash_attn_varlen_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -918,6 +954,7 @@ def flash_attn_varlen_func( ...@@ -918,6 +954,7 @@ def flash_attn_varlen_func(
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_attn_probs, return_attn_probs,
) )
...@@ -937,6 +974,7 @@ def flash_attn_with_kvcache( ...@@ -937,6 +974,7 @@ def flash_attn_with_kvcache(
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
rotary_interleaved=True, rotary_interleaved=True,
num_splits=0, num_splits=0,
alibi_slopes=None,
): ):
""" """
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
...@@ -1041,5 +1079,6 @@ def flash_attn_with_kvcache( ...@@ -1041,5 +1079,6 @@ def flash_attn_with_kvcache(
window_size[1], window_size[1],
rotary_interleaved, rotary_interleaved,
num_splits, num_splits,
alibi_slopes
) )
return out return out
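For reference, a minimal sketch of the extended Python interface in use. The shapes and slope values below are illustrative assumptions (any per-head slope values work); `alibi_slopes` holds one fp32 slope per (batch, head) and is passed through the new keyword argument:

import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 512, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# One slope per (batch, head), kept in fp32 and contiguous in the last dimension.
slopes = torch.tensor([2.0 ** -(i + 1) for i in range(nheads)], device="cuda", dtype=torch.float32)
alibi_slopes = slopes.unsqueeze(0).expand(batch, nheads).contiguous()

out = flash_attn_func(q, k, v, causal=True, alibi_slopes=alibi_slopes)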
import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (flash_attn_func, flash_attn_kvpacked_func,
flash_attn_qkvpacked_func, flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size
from flash_attn.flash_attn_triton import \
flash_attn_func as flash_attn_func_triton
from flash_attn.layers.rotary import apply_rotary_emb
MAX_HEADDIM_SM8x = 192
is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
def generate_alibi(max_seq_len, num_attention_heads, tp_world_size, tp_index, key_padding_mask=None, device="cuda"):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = (2 ** (-2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio ** i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
:n - closest_power_of_2]
slopes = torch.tensor(get_slopes(num_attention_heads)).to(device=device)
# Select the part of the tensor that corresponds to our tensor parallel index.
    assert (num_attention_heads / tp_world_size).is_integer(), \
        "num_attention_heads must be divisible by tp_world_size"
nh_tp = num_attention_heads // tp_world_size
slopes = slopes[nh_tp * tp_index:nh_tp * (tp_index + 1)]
    if key_padding_mask is None:
arange_tensor = rearrange(torch.arange(max_seq_len), "sqk -> 1 sqk").to(device=device)
else:
arange_tensor = (key_padding_mask.cumsum(dim=-1, dtype=slopes.dtype) - 1) \
.masked_fill_(~key_padding_mask, torch.finfo(torch.float).min).to(device=device)
arange_tensor = rearrange(arange_tensor, 'b sqk -> b 1 1 sqk')
# (1, nheads, 1, seqlen_k) or (batch, nheads, 1, seqlen_k)
alibi_tensor = rearrange(slopes, 'nh -> 1 nh 1 1') * arange_tensor
return alibi_tensor, slopes
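# Aside (illustrative, not part of the original test file): for a power-of-two head count the
# slopes above are the usual ALiBi geometric sequence, e.g. for 8 heads
# [1/2, 1/4, 1/8, ..., 1/256]. generate_alibi() biases scores with slope * key_index rather
# than the canonical -slope * (query_index - key_index); because softmax is shift-invariant
# within each query row, both forms give identical probabilities under a causal mask:
def _check_alibi_bias_equivalence(seqlen=5, slope=0.25):
    torch.manual_seed(0)
    scores = torch.randn(seqlen, seqlen)
    i = torch.arange(seqlen).unsqueeze(-1)  # query index
    j = torch.arange(seqlen).unsqueeze(0)   # key index
    causal = j > i
    p_colwise = torch.softmax((scores + slope * j).masked_fill(causal, float("-inf")), dim=-1)
    p_canonical = torch.softmax((scores - slope * (i - j)).masked_fill(causal, float("-inf")), dim=-1)
    assert torch.allclose(p_colwise, p_canonical, atol=1e-6)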
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", right_padding=True):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen,
device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
)
elif mode == "third":
lengths = torch.randint(
max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
if right_padding:
padding_mask = (
repeat(torch.arange(max_seqlen, device=device),
"s -> b s", b=batch_size) < lengths
)
else:
padding_mask = (
repeat(torch.arange(start=max_seqlen-1, end=-1, step=-1, device=device),
"s -> b s", b=batch_size) < lengths
)
return padding_mask
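# For example (illustrative only): with max_seqlen=6 and a sampled length of 4, the two modes
# place the valid positions at opposite ends of the sequence:
#   right_padding=True  -> [True, True, True, True, False, False]   (padding on the right)
#   right_padding=False -> [False, False, True, True, True, True]   (padding on the left)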
def generate_qkv(
q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask)
def output_pad_fn(output_unpad): return pad_input(
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
)
max_seqlen_q = seqlen_q
def output_pad_fn(output_unpad): return rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(
k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert (query_padding_mask == key_padding_mask).all()
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
if query_padding_mask is not None:
def dqkv_pad_fn(dqkv_unpad): return pad_input(
dqkv_unpad, indices_q, batch_size, seqlen_q)
else:
def dqkv_pad_fn(dqkv_unpad): return rearrange(
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
max_seqlen_q,
qkv.detach().requires_grad_(),
output_pad_fn,
dqkv_pad_fn,
)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
def dkv_pad_fn(dkv_unpad): return pad_input(
dkv_unpad, indices_k, batch_size, seqlen_k)
else:
def dkv_pad_fn(dkv_unpad): return rearrange(
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
kv.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
)
else:
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
def dk_pad_fn(dk_unpad): return pad_input(
dk_unpad, indices_k, batch_size, seqlen_k)
else:
def dk_pad_fn(dk_unpad): return rearrange(
dk_unpad, "(b s) h d -> b s h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
k_unpad.detach().requires_grad_(),
v_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
)
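# Illustrative only (relies on unpad_input imported above): the varlen bookkeeping that
# generate_qkv() produces for a tiny batch of two sequences with lengths 3 and 5, packed
# into a single (total_tokens, nheads, headdim) tensor.
def _example_varlen_bookkeeping():
    key_padding_mask = torch.tensor([[True, True, True, False, False],
                                     [True, True, True, True, True]])
    k = torch.randn(2, 5, 4, 64)
    k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
    assert k_unpad.shape == (8, 4, 64)
    assert cu_seqlens_k.tolist() == [0, 3, 8]
    assert max_seqlen_k == 5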
def construct_local_mask(
seqlen_q,
seqlen_k,
window_size=(-1, -1), # -1 means infinite window size
query_padding_mask=None,
key_padding_mask=None,
device=None,
):
row_idx = rearrange(torch.arange(
seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(
col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
col_idx < row_idx + sk - sq - window_size[0],
)
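# Convention check (illustrative): construct_local_mask() returns True where a position is
# masked *out*. With equal q/k lengths and a causal window (-1, 0), that is exactly the
# strict upper triangle, i.e. each query attends only to keys at or before its own position.
def _example_causal_local_mask(seqlen=4):
    mask = construct_local_mask(seqlen, seqlen, window_size=(-1, 0), device="cpu")
    i = torch.arange(seqlen).unsqueeze(-1)  # query index
    j = torch.arange(seqlen).unsqueeze(0)   # key index
    assert torch.equal(mask, j > i)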
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
reorder_ops=False,
bias=None
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
    reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if bias is not None:
bias = bias.to(scores.dtype)
scores += bias
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask,
"b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
attention = torch.softmax(scores, dim=-1)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(
torch.all(local_mask, dim=-1, keepdim=True), 0.0)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
output = torch.einsum(
"bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(
rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
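# Illustrative usage of the new `bias` argument (shapes here are assumptions, not taken from
# the tests below): the dense ALiBi tensor from generate_alibi() goes to the reference via
# `bias=`, while the per-head slopes go to the fused kernel via `alibi_slopes=`.
def _example_alibi_reference_vs_kernel(batch=2, seqlen=128, nheads=8, headdim=64):
    q, k, v = [torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
               for _ in range(3)]
    alibi_tensor, alibi_slopes = generate_alibi(seqlen, nheads, tp_world_size=1, tp_index=0)
    out_ref, _ = attention_ref(q, k, v, causal=True, bias=alibi_tensor)
    out = flash_attn_func(q, k, v, causal=True,
                          alibi_slopes=repeat(alibi_slopes, "nh -> b nh", b=batch).contiguous())
    print(f"max abs diff vs reference: {(out - out_ref).abs().max().item()}")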
def attention_kvpacked_ref(
q,
kv,
query_padding_mask=None,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
reorder_ops=False,
):
return attention_ref(
q,
kv[:, :, 0],
kv[:, :, 1],
query_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
window_size=window_size,
reorder_ops=reorder_ops,
)
def attention_qkvpacked_ref(
qkv,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
reorder_ops=False,
):
return attention_ref(
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
key_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
window_size=window_size,
reorder_ops=reorder_ops,
)
def generate_sparsity_mask(seqlen, sparsity=0.3):
repeats = seqlen // 16 // 2
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
nrow, ncol = seqlen // 16, seqlen // 256
mask = torch.rand(nrow, ncol, device="cuda") < sparsity
return mask
def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, head_dim)
blockmask: (seqlen / 16, seqlen / 256)
attn_mask: (batch_size, seqlen)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen, seqlen)
Output:
output: (batch_size, seqlen, nheads, head_dim)
attention: softmax after dropout
"""
q, k, v = qkv.float().unbind(dim=2)
d = qkv.shape[-1]
seqlen = qkv.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
blockmask = blockmask[:seqlen, :seqlen]
scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
attention = torch.softmax(scores, dim=-1)
attention = attention.masked_fill(
rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
attention = attention.masked_fill_(
rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
attention_drop = attention.masked_fill(
~dropout_mask, 0.0) / (1 - dropout_p)
output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
def convert_flash_attn_S_to_softmax(
S,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
head_dim,
is_dropout,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""FlashAttention stores the S matrix in a different way.
Arguments:
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
query_padding_mask: (batch_size, seqlen_q_rounded)
key_padding_mask: (batch_size, seqlen_k_rounded)
"""
if causal:
window_size = (window_size[0], 0)
seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
warps_n = 4
blocksize_m, blocksize_n = _get_block_size(
S.device, head_dim, is_dropout, causal)
nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
mmas_n = (blocksize_n + 16 - 1) // 16
S_flat = rearrange(
S,
"b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
blocksize_m=blocksize_m,
blocksize_n=blocksize_n,
)
S_converted = rearrange(
S_flat,
"b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
mmas_n=mmas_n,
warps_n=warps_n,
eight=8,
c0=2,
c1=2,
c2=2,
four=4,
)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
S.device,
)
local_mask = F.pad(
local_mask,
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
value=True,
)
S_converted.masked_fill_(local_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten.
seqlen_q_og = (
query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
)
if query_padding_mask is not None:
query_padding_mask = F.pad(
query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
S_converted = S_converted.masked_fill(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
if key_padding_mask is not None:
key_padding_mask = F.pad(
key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
S_converted = S_converted.masked_fill(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
return S_converted[:, :, :seqlen_q, :seqlen_k]
def normalize_flash_attn_S(
attn_unnorm,
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
is_dropout=False,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k, v: (batch_size, seqlen_k, nheads, head_dim)
        key_padding_mask: (batch_size, seqlen_k)
    Output:
        attn_norm: (batch_size, nheads, seqlen_q, seqlen_k)
"""
if causal:
window_size = (window_size[0], 0)
q, k, v = q.float(), k.float(), v.float()
_, seqlen_q, _, head_dim = q.shape
seqlen_k = k.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask,
"b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
_, block_size_n = _get_block_size(
scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1)
for s in scores_block], dim=-1)
lse = torch.logsumexp(lse_block, dim=-1)
# lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
# so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
lse[lse == float("-inf")] = float("inf")
scores_max_block = torch.stack(
[torch.amax(s, dim=-1) for s in scores_block], dim=-1)
cummax_block = torch.cummax(
scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
attn_norm = torch.cat(
[
a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
for a, m in zip(attn_unnorm_block, cummax_block)
],
dim=-1,
)
if query_padding_mask is not None:
attn_norm.masked_fill_(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
return attn_norm.to(dtype=attn_unnorm.dtype)
def get_dropout_fraction(
dropout_mask,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
"""
if causal:
window_size = (window_size[0], 0)
batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
dropped = ~dropout_mask
valid = torch.ones_like(dropout_mask)
if query_padding_mask is not None:
dropped.masked_fill_(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
valid.masked_fill_(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
if key_padding_mask is not None:
dropped.masked_fill_(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
valid.masked_fill_(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
dropout_mask.device,
)
dropped.masked_fill_(local_mask, False)
valid.masked_fill_(local_mask, False)
    dropped_total = dropped.sum()
    return dropped_total / valid.sum()
@pytest.mark.parametrize(
"dtype", [torch.float16]
)
@pytest.mark.parametrize(
"b_sq",
[
(32, 512),
(16, 1024),
(8, 2048),
(4, 4096),
(2, 8192),
(1, 16384)
]
)
@pytest.mark.parametrize(
"nh_hd",
[
(32, 64),
(16, 128),
(40, 128) # non power of 2 nh
]
)
@pytest.mark.parametrize(
"tp_world_size", [1, 2, 4]
)
def test_flash_attn_func(b_sq, nh_hd, tp_world_size, dtype):
b, sq = b_sq
nh, hd = nh_hd
nh_tp = nh // tp_world_size
q, k, v = [torch.randn(b, sq, nh_tp, hd, device="cuda",
dtype=dtype, requires_grad=True) for _ in range(3)]
dout = torch.rand_like(q)
for tp_index in range(tp_world_size):
alibi, alibi_slopes = generate_alibi(
max_seq_len=sq,
num_attention_heads=nh,
tp_world_size=tp_world_size,
tp_index=tp_index,
key_padding_mask=None,
device="cuda"
)
triton_out = flash_attn_func_triton(
q, k, v, alibi, True, hd**(-0.5))
triton_out.backward(dout)
triton_dq, q.grad = q.grad.clone(), None
triton_dk, k.grad = k.grad.clone(), None
triton_dv, v.grad = v.grad.clone(), None
flash_out = flash_attn_func(q, k, v, causal=True, alibi_slopes=repeat(alibi_slopes, "nh -> b nh", b=b))
flash_out.backward(dout)
flash_dq, q.grad = q.grad.clone(), None
flash_dk, k.grad = k.grad.clone(), None
flash_dv, v.grad = v.grad.clone(), None
assert torch.allclose(flash_out, triton_out, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dq, triton_dq, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dk, triton_dk, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dv, triton_dv, atol=1e-2, rtol=0.)
@pytest.mark.parametrize(
"dtype", [torch.float16]
)
@pytest.mark.parametrize(
"right_padding", [True, False]
)
@pytest.mark.parametrize(
"b_sq",
[
(32, 512),
(16, 1024),
(8, 2048),
(4, 4096),
(2, 8192),
(1, 16384)
]
)
@pytest.mark.parametrize(
"nh_hd",
[
(32, 64),
(16, 128),
(40, 128) # non power of 2 nh
]
)
@pytest.mark.parametrize(
"tp_world_size", [1, 2, 4]
)
def test_flash_attn_varlen_func(b_sq, nh_hd, tp_world_size, right_padding, dtype):
b, sqk = b_sq
nh, hd = nh_hd
nh_tp = nh // tp_world_size
    # The causal-mask conventions of flash_attn_func_triton() and flash-attention v2 (>= v2.1) differ,
    # so the outputs only match when seqlen_q == 1 and the Triton version is called with causal=False.
    # https://github.com/huggingface/text-generation-inference/blob/v1.1.1/server/text_generation_server/models/custom_modeling/mpt_modeling.py#L53-L63
q = torch.randn(b, 1, nh_tp, hd, device="cuda", dtype=dtype, requires_grad=True)
k, v = [torch.randn(b, sqk, nh_tp, hd, device="cuda",
dtype=dtype, requires_grad=True) for _ in range(2)]
dout = torch.rand_like(q)
padding_mask = generate_random_padding_mask(sqk, b, "cuda", "random", right_padding)
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, None, padding_mask, kvpacked=False)
for tp_index in range(tp_world_size):
alibi, alibi_slopes = generate_alibi(
max_seq_len=sqk,
num_attention_heads=nh,
tp_world_size=tp_world_size,
tp_index=tp_index,
key_padding_mask=padding_mask,
device="cuda"
)
triton_out = flash_attn_func_triton(
q, k, v, alibi, False, hd**(-0.5))
triton_out.backward(dout)
triton_dq, q.grad = q.grad.clone(), None
triton_dk, k.grad = k.grad.clone(), None
triton_dv, v.grad = v.grad.clone(), None
flash_out_unpad = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
causal=True,
alibi_slopes=repeat(alibi_slopes, "nh -> b nh", b=b)
)
flash_out = output_pad_fn(flash_out_unpad)
flash_out.backward(dout)
flash_dq_unpad, q_unpad.grad = q_unpad.grad.clone(), None
flash_dk_unpad, k_unpad.grad = k_unpad.grad.clone(), None
flash_dv_unpad, v_unpad.grad = v_unpad.grad.clone(), None
flash_dq = dq_pad_fn(flash_dq_unpad)
flash_dk = dk_pad_fn(flash_dk_unpad)
flash_dv = dk_pad_fn(flash_dv_unpad)
assert torch.allclose(flash_out, triton_out, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dq, triton_dq, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dk, triton_dk, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dv, triton_dv, atol=1e-2, rtol=0.)
@pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [True])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 128),
(1, 339),
(3, 1024),
(64, 800),
(64, 256),
(3, 799),
(64, 2048),
(16, 20000),
(1, 128 * 1024),
(16, 128 * 1024),
(128, 128),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
seqlen_q,
seqlen_k,
d,
has_batch_idx,
rotary_fraction,
rotary_interleaved,
seqlen_new_eq_seqlen_q,
causal,
local,
new_kv,
mha_type,
num_splits,
dtype,
alibi,
):
if seqlen_q > seqlen_k and new_kv:
pytest.skip()
if not new_kv and rotary_fraction > 0.0:
pytest.skip()
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 2
batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
nheads = 8
# rotary_dim must be a multiple of 16, and must be <= d
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 4)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads,
d, device=device, dtype=dtype)
seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(
1, seqlen_q + 1, (1,)).item()
if new_kv:
k = torch.randn(batch_size, seqlen_new, nheads_k,
d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_new, nheads_k,
d, device=device, dtype=dtype)
else:
k, v = None, None
k_cache = torch.randn(batch_size_cache, seqlen_k,
nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(batch_size_cache, seqlen_k,
nheads_k, d, device=device, dtype=dtype)
cache_seqlens = torch.randint(
0,
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(seqlen_k - (seqlen_q if (causal or local)
and rotary_dim > 1 else seqlen_new) + 1)
if new_kv
else (seqlen_k + 1),
(batch_size,),
dtype=torch.int32,
device=device,
)
if has_batch_idx:
cache_batch_idx = torch.randperm(
batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
else:
cache_batch_idx = None
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
if rotary_dim > 0:
angle = torch.rand(seqlen_k, rotary_dim // 2,
device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if causal or local:
q_ro = apply_rotary_emb(
q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
q_ro = rearrange(
apply_rotary_emb(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
seqlen_offsets=cache_seqlens,
interleaved=rotary_interleaved,
),
"b 1 (s h) d -> b s h d",
s=seqlen_q,
)
# q_ro = q
k_ro = apply_rotary_emb(
k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
cos, sin = None, None
q_ro, k_ro = q, k
# k_cache[:, 64:] = -1
k_cache_ref = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
v_cache_ref = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
if new_kv:
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
)
k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
k_cache_rep = repeat(
k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_cache_rep = repeat(
v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
if alibi:
seqlen_alibi = k_cache_rep.shape[1]
alibi_tensor, alibi_slopes = generate_alibi(
max_seq_len=seqlen_alibi,
num_attention_heads=nheads,
tp_world_size=1,
tp_index=0,
key_padding_mask=None,
device="cuda"
)
# alibi_tensor = alibi_tensor.expand(batch_size, -1, seqlen_q, -1)
alibi_slopes = repeat(alibi_slopes, "nh -> b nh", b=batch_size)
if alibi_tensor.abs().max().item() >= torch.finfo(dtype).max:
pytest.skip()
else:
alibi_tensor, alibi_slopes = None, None
out = flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k,
v,
cos,
sin,
cache_seqlens,
cache_batch_idx,
causal=causal,
window_size=window_size,
rotary_interleaved=rotary_interleaved,
num_splits=num_splits,
alibi_slopes=alibi_slopes
)
# out = flash_attn_with_kvcache(
# q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
# )
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# probs = torch.softmax(qk, dim=-1)
key_padding_mask = arange < cache_seqlens_expanded + \
(seqlen_new if new_kv else 0)
out_ref, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
0.0,
None,
causal=causal,
window_size=window_size,
bias=alibi_tensor
)
out_pt, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
0.0,
None,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
bias=alibi_tensor
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
if new_kv:
k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
assert torch.allclose(k_cache_select, k_cache_ref,
rtol=1e-3, atol=1e-3)
assert torch.equal(v_cache_select, v_cache_ref)
assert (out - out_ref).abs().max().item() <= 3 * \
(out_pt - out_ref).abs().max().item() + 1e-5