Commit 5ab9b366 authored by Tri Dao

Clean up alibi, implement non-causal alibi

parent bc28eacc
...@@ -253,12 +253,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size ...@@ -253,12 +253,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, const float p_dropout,
const float softmax_scale, const float softmax_scale,
bool is_causal, bool is_causal,
const int window_size_left, const int window_size_left,
int window_size_right, int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
const bool return_softmax, const bool return_softmax,
c10::optional<at::Generator> gen_) { c10::optional<at::Generator> gen_) {
...@@ -297,13 +297,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size ...@@ -297,13 +297,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case // causal=true is the same as causal=false in this case
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
if (is_causal) { window_size_right = 0; } if (is_causal) { window_size_right = 0; }
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza // H/t Daniel Haziza
// TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together? const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
if (seqlenq_ngroups_swapped) { if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k; const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
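For intuition, a minimal PyTorch sketch (sizes and variable names are illustrative, not from this diff) of the seqlenq_ngroups_swapped trick above: with a single query position and grouped-query attention, the query-head groups are reinterpreted as a fake sequence dimension so each KV head sees ngroups "query positions" instead of one. The has_value() check keeps this path off when ALiBi slopes are supplied, presumably because after the swap the head index and query position no longer line up with what the per-head bias expects (the removed TODO above asked exactly how to make the two work together).

import torch

# Illustrative sizes: 12 query heads sharing 4 KV heads (grouped-query attention).
b, h_kv, ngroups, d = 2, 4, 3, 64
h = h_kv * ngroups
q = torch.randn(b, 1, h, d)                                  # (b, seqlen_q=1, nheads, d)

# Same reshape/transpose as the C++ above: view the head groups as a fake
# sequence dimension of length ngroups.
q_swapped = q.reshape(b, h_kv, ngroups, d).transpose(1, 2)   # (b, ngroups, nheads_kv, d)
assert q_swapped.shape == (b, ngroups, h_kv, d)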
...@@ -416,12 +416,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size ...@@ -416,12 +416,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes); CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr(); params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0); params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
} else { } else {
params.has_alibi = false; params.alibi_slopes_ptr = nullptr;
} }
if (seqlen_k > 0) { if (seqlen_k > 0) {
...@@ -456,6 +455,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q ...@@ -456,6 +455,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1 const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q, const int max_seqlen_q,
const int max_seqlen_k, const int max_seqlen_k,
const float p_dropout, const float p_dropout,
...@@ -464,7 +464,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q ...@@ -464,7 +464,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const bool is_causal, const bool is_causal,
const int window_size_left, const int window_size_left,
int window_size_right, int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
const bool return_softmax, const bool return_softmax,
c10::optional<at::Generator> gen_) { c10::optional<at::Generator> gen_) {
...@@ -612,12 +611,11 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q ...@@ -612,12 +611,11 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes); CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr(); params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0); params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
} else { } else {
params.has_alibi = false; params.alibi_slopes_ptr = nullptr;
} }
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
...@@ -664,12 +662,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si ...@@ -664,12 +662,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
const float softmax_scale, const float softmax_scale,
const bool is_causal, const bool is_causal,
const int window_size_left, const int window_size_left,
int window_size_right, int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
c10::optional<at::Generator> gen_, c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) { c10::optional<at::Tensor> &rng_state) {
...@@ -848,12 +846,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si ...@@ -848,12 +846,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes); CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr(); params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0); params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
} else { } else {
params.has_alibi = false; params.alibi_slopes_ptr = nullptr;
} }
if (seqlen_q > 0) { if (seqlen_q > 0) {
...@@ -891,6 +888,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size ...@@ -891,6 +888,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1 const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q, const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
...@@ -899,7 +897,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size ...@@ -899,7 +897,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const bool is_causal, const bool is_causal,
const int window_size_left, const int window_size_left,
int window_size_right, int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
c10::optional<at::Generator> gen_, c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) { c10::optional<at::Tensor> &rng_state) {
...@@ -1094,12 +1091,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size ...@@ -1094,12 +1091,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes); CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr(); params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0); params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
} else { } else {
params.has_alibi = false; params.alibi_slopes_ptr = nullptr;
} }
launch(params, stream, /*configure=*/false); launch(params, stream, /*configure=*/false);
...@@ -1128,14 +1124,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ...@@ -1128,14 +1124,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2) c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2) c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale, const float softmax_scale,
bool is_causal, bool is_causal,
const int window_size_left, const int window_size_left,
int window_size_right, int window_size_right,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits, int num_splits
c10::optional<at::Tensor> &alibi_slopes_ // batch_size x num_heads
) { ) {
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
...@@ -1174,13 +1170,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ...@@ -1174,13 +1170,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case // causal=true is the same as causal=false in this case
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
if (is_causal) { window_size_right = 0; } if (is_causal) { window_size_right = 0; }
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza // H/t Daniel Haziza
// TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together? const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
if (seqlenq_ngroups_swapped) { if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k; const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
...@@ -1347,12 +1343,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ...@@ -1347,12 +1343,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes); CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr(); params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0); params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
} else { } else {
params.has_alibi = false; params.alibi_slopes_ptr = nullptr;
} }
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
......
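For reference, a hedged Python sketch (the helper name resolve_alibi_slopes is hypothetical) of the slope-shape rule that the checks above enforce in all five entry points: either one fp32 slope per head, shared across the batch via a zero batch stride, or a full (batch_size, num_heads) table.

import torch

def resolve_alibi_slopes(alibi_slopes, batch_size, num_heads):
    # Mirrors the TORCH_CHECKs above: fp32, contiguous last dimension, and a
    # shape of either (num_heads,) or (batch_size, num_heads).
    assert alibi_slopes.dtype == torch.float32
    assert alibi_slopes.stride(-1) == 1
    assert tuple(alibi_slopes.shape) in ((num_heads,), (batch_size, num_heads))
    # A 1-D tensor gets a batch stride of 0, so slopes_ptr[b * stride + h]
    # reads the same slope for every batch index b (broadcast over the batch).
    batch_stride = alibi_slopes.stride(0) if alibi_slopes.dim() == 2 else 0
    return alibi_slopes.data_ptr(), batch_stride

slopes = torch.full((8,), 0.0625)                  # one slope per head
_, stride = resolve_alibi_slopes(slopes, batch_size=4, num_heads=8)
assert stride == 0                                 # shared across the whole batch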
...@@ -13,22 +13,32 @@ using namespace cute; ...@@ -13,22 +13,32 @@ using namespace cute;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout> template <bool Is_causal, typename Engine, typename Layout>
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor, inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_, const int col_idx_offset_,
const int max_seqlen_k, const int max_seqlen_k,
const int row_idx_offset_, const int row_idx_offset,
const int max_seqlen_q, const int max_seqlen_q,
const int warp_row_stride, const int warp_row_stride,
const int head_idx,
const float softmax_scale,
const float alibi_slope) { const float alibi_slope) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor"); static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32; const int lane_id = threadIdx.x % 32;
const int row_idx_offset = row_idx_offset_;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
const float alibi_slope_unscaled = alibi_slope / softmax_scale; if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
}
}
} else { // Bias depends on both row_idx and col_idx
#pragma unroll #pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride; const int row_idx_base = row_idx_offset + mi * warp_row_stride;
...@@ -41,9 +51,7 @@ inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor, ...@@ -41,9 +51,7 @@ inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
#pragma unroll #pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) { for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j; const int col_idx = col_idx_base + j;
const float alibi = alibi_slope_unscaled * col_idx; tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
if (col_idx < max_seqlen_k && row_idx < max_seqlen_q) {
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi;
} }
} }
} }
......
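For reference, a small PyTorch sketch (illustrative, outside the kernel and ignoring the softmax-scale bookkeeping) of the two branches above: the non-causal branch subtracts slope * |i + seqlen_k - seqlen_q - j|, while the causal branch adds the same slope * j vector to every row. Under a causal mask the two differ only by a per-row constant, which the row-wise softmax cancels, so the cheaper form is safe in the causal case.

import torch

def alibi_bias(slope, seqlen_q, seqlen_k, causal):
    i = torch.arange(seqlen_q).view(-1, 1)        # query rows
    j = torch.arange(seqlen_k).view(1, -1)        # key columns
    if causal:
        # Same bias vector for every row; differs from the full bias only by a
        # per-row constant, which softmax ignores.
        return (slope * j).expand(seqlen_q, seqlen_k)
    return -slope * (i + seqlen_k - seqlen_q - j).abs()      # non-causal branch

slope, sq, sk = 0.25, 4, 6
scores = torch.randn(sq, sk)
causal_mask = torch.arange(sk) > torch.arange(sq).view(-1, 1) + sk - sq
p_causal = torch.softmax((scores + alibi_bias(slope, sq, sk, causal=True))
                         .masked_fill(causal_mask, float("-inf")), dim=-1)
p_full = torch.softmax((scores + alibi_bias(slope, sq, sk, causal=False))
                       .masked_fill(causal_mask, float("-inf")), dim=-1)
assert torch.allclose(p_causal, p_full, atol=1e-6)           # same softmax under the mask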
...@@ -131,10 +131,6 @@ struct Flash_fwd_params : public Qkv_params { ...@@ -131,10 +131,6 @@ struct Flash_fwd_params : public Qkv_params {
int num_splits; // For split-KV version int num_splits; // For split-KV version
// float alibi_start;
// float alibi_ratio;
bool has_alibi;
void * __restrict__ alibi_slopes_ptr; void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride; index_t alibi_slopes_batch_stride;
}; };
......
...@@ -753,8 +753,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -753,8 +753,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
#pragma unroll #pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi)); const int row = get<0>(taccScS_row(mi));
lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0; lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
} }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.
// Tensor tKrK = make_fragment_like(tKsK); // Tensor tKrK = make_fragment_like(tKsK);
// // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK); // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
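A tiny numeric illustration (not kernel code) of the LSE = INFINITY change explained above: with LSE = 0, an ALiBi-shifted out-of-bounds row can overflow under the exponential and then turn into NaN when multiplied by the zero-filled V, whereas LSE = +inf pins those probabilities to exactly 0.

import torch

v = torch.zeros(4)                              # OOB rows read zero-filled V
scores = 50.0 * torch.arange(4.0)               # zero scores shifted by an ALiBi-style bias
print(torch.exp(scores - 0.0) * v)              # LSE = 0: exp overflows -> inf * 0 = NaN
print(torch.exp(scores - float("inf")) * v)     # LSE = +inf: probabilities are exactly 0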
...@@ -792,18 +796,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -792,18 +796,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
clear(acc_dv); clear(acc_dv);
clear(acc_dk); clear(acc_dk);
float alibi_slope = 0.0f; float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
}
for (; m_block >= m_block_min; --m_block) { for (; m_block >= m_block_min; --m_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N) Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
...@@ -830,14 +823,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -830,14 +823,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// if (cute::thread(32, 0)) { print(scores); } // if (cute::thread(32, 0)) { print(scores); }
if (Has_alibi) { if (Has_alibi) {
flash::apply_alibi( flash::apply_alibi<Is_causal>(
scores, scores,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)), m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q, binfo.actual_seqlen_q,
AtomLayoutMS * 16, AtomLayoutMS * 16,
bidh, params.scale_softmax,
alibi_slope alibi_slope
); );
} }
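Note on the slope load above: the kernel folds the softmax scaling into a later step, so the slope is divided by params.scale_softmax up front (replacing the alibi_slope / softmax_scale division that previously lived inside apply_alibi). A one-line arithmetic check with illustrative numbers:

# Adding (slope / scale) * dist before the scores are scaled by `scale` equals
# adding slope * dist after scaling, which is the intended (unscaled) ALiBi bias.
qk, scale, slope, dist = 1.7, 0.125, 0.0625, 5.0
assert abs((qk - (slope / scale) * dist) * scale - (qk * scale - slope * dist)) < 1e-12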
...@@ -1403,18 +1395,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1403,18 +1395,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
clear(acc_dq); clear(acc_dq);
float alibi_slope = 0.0f; float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
}
for (; n_block >= 0; --n_block) { for (; n_block >= 0; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M_SdP, MMA_N) Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M_SdP, MMA_N)
...@@ -1429,14 +1410,13 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1429,14 +1410,13 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) { if (Has_alibi) {
flash::apply_alibi( flash::apply_alibi<Is_causal>(
scores, scores,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)), m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q, binfo.actual_seqlen_q,
AtomLayoutMS * 16, AtomLayoutMS * 16,
bidh, params.scale_softmax,
alibi_slope alibi_slope
); );
} }
......
...@@ -64,12 +64,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, ...@@ -64,12 +64,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] { BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false // If Is_local, set Is_causal to false
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>; auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) { if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute( C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
...@@ -109,7 +108,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream, ...@@ -109,7 +108,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] { BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>; auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
......
...@@ -322,28 +322,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -322,28 +322,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
clear(acc_o); clear(acc_o);
float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
// For performance reason, we separate out two kinds of iterations: // For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't. // those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V has length not multiple of kBlockN. // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration. // We will have at least 1 "masking" iteration.
float alibi_slope = 0.0f;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
// if (m_block == 0 && tidx == 0) {
// printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
// }
}
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local) constexpr int n_masking_steps = (!Is_causal && !Is_local)
...@@ -382,14 +368,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -382,14 +368,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// can produce Inf / NaN. // can produce Inf / NaN.
if (Has_alibi) { if (Has_alibi) {
flash::apply_alibi( flash::apply_alibi<Is_causal>(
scores, scores,
n_block * kBlockN, n_block * kBlockN,
binfo.actual_seqlen_k, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, binfo.actual_seqlen_q,
kNWarps * 16, kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope alibi_slope
); );
} }
...@@ -500,14 +485,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -500,14 +485,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) { if (Has_alibi) {
flash::apply_alibi( flash::apply_alibi<Is_causal>(
scores, scores,
n_block * kBlockN, n_block * kBlockN,
binfo.actual_seqlen_k, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, binfo.actual_seqlen_q,
kNWarps * 16, kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope alibi_slope
); );
} }
...@@ -950,28 +934,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -950,28 +934,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
clear(acc_o); clear(acc_o);
float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
// For performance reason, we separate out two kinds of iterations: // For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't. // those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V has length not multiple of kBlockN. // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration. // We will have at least 1 "masking" iteration.
float alibi_slope = 0.0f;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
// if (m_block == 0 && tidx == 0) {
// printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
// }
}
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local) constexpr int n_masking_steps = (!Is_causal && !Is_local)
...@@ -1006,14 +976,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -1006,14 +976,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) { if (Has_alibi) {
flash::apply_alibi( flash::apply_alibi<Is_causal>(
scores, scores,
n_block * kBlockN, n_block * kBlockN,
binfo.actual_seqlen_k, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, binfo.actual_seqlen_q,
kNWarps * 16, kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope alibi_slope
); );
} }
...@@ -1099,14 +1068,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -1099,14 +1068,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) { if (Has_alibi) {
flash::apply_alibi( flash::apply_alibi<Is_causal>(
scores, scores,
n_block * kBlockN, n_block * kBlockN,
binfo.actual_seqlen_k, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, binfo.actual_seqlen_q,
kNWarps * 16, kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope alibi_slope
); );
} }
......
...@@ -45,7 +45,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -45,7 +45,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] { BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// Will only return softmax if dropout, to reduce compilation time. // Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates // If return_softmax, set IsEvenMNConst to false to reduce number of templates
...@@ -86,7 +86,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -86,7 +86,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] { BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If Is_local, set Is_causal to false // If Is_local, set Is_causal to false
......
...@@ -141,14 +141,12 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_ ...@@ -141,14 +141,12 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
template <bool HasWSLeft=true, typename Engine, typename Layout> template <bool HasWSLeft=true, typename Engine, typename Layout>
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_, inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_, const int max_seqlen_k, const int row_idx_offset,
const int max_seqlen_q, const int warp_row_stride, const int max_seqlen_q, const int warp_row_stride,
const int window_size_left, const int window_size_right) { const int window_size_left, const int window_size_right) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor"); static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32; const int lane_id = threadIdx.x % 32;
// const int row_idx_offset = row_idx_offset_ + lane_id / 4;
const int row_idx_offset = row_idx_offset_;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll #pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
...@@ -180,17 +178,17 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in ...@@ -180,17 +178,17 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in
template <typename Engine, typename Layout> template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_, inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_, const int max_seqlen_k, const int row_idx_offset,
const int max_seqlen_q, const int warp_row_stride) { const int max_seqlen_q, const int warp_row_stride) {
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
max_seqlen_q, warp_row_stride, -1, 0); max_seqlen_q, warp_row_stride, -1, 0);
} }
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1> template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx( inline __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol, Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_) const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{ {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout0::rank == 2, "Only support 2D Tensor");
...@@ -199,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx( ...@@ -199,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
#pragma unroll #pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) { for (int mi = 0; mi < size<0>(tensor); ++mi) {
const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
......
...@@ -53,12 +53,12 @@ def _flash_attn_forward( ...@@ -53,12 +53,12 @@ def _flash_attn_forward(
k, k,
v, v,
None, None,
alibi_slopes,
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
return_softmax, return_softmax,
None, None,
) )
...@@ -90,6 +90,7 @@ def _flash_attn_varlen_forward( ...@@ -90,6 +90,7 @@ def _flash_attn_varlen_forward(
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
None, None,
alibi_slopes,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
dropout_p, dropout_p,
...@@ -98,7 +99,6 @@ def _flash_attn_varlen_forward( ...@@ -98,7 +99,6 @@ def _flash_attn_varlen_forward(
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
return_softmax, return_softmax,
None, None,
) )
...@@ -137,12 +137,12 @@ def _flash_attn_backward( ...@@ -137,12 +137,12 @@ def _flash_attn_backward(
dq, dq,
dk, dk,
dv, dv,
alibi_slopes,
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
None, None,
rng_state, rng_state,
) )
...@@ -185,6 +185,7 @@ def _flash_attn_varlen_backward( ...@@ -185,6 +185,7 @@ def _flash_attn_varlen_backward(
dv, dv,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
alibi_slopes,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
dropout_p, dropout_p,
...@@ -193,7 +194,6 @@ def _flash_attn_varlen_backward( ...@@ -193,7 +194,6 @@ def _flash_attn_varlen_backward(
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
None, None,
rng_state, rng_state,
) )
...@@ -613,6 +613,8 @@ def flash_attn_qkvpacked_func( ...@@ -613,6 +613,8 @@ def flash_attn_qkvpacked_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -673,6 +675,9 @@ def flash_attn_kvpacked_func( ...@@ -673,6 +675,9 @@ def flash_attn_kvpacked_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -732,6 +737,9 @@ def flash_attn_func( ...@@ -732,6 +737,9 @@ def flash_attn_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -780,6 +788,8 @@ def flash_attn_varlen_qkvpacked_func( ...@@ -780,6 +788,8 @@ def flash_attn_varlen_qkvpacked_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -858,6 +868,9 @@ def flash_attn_varlen_kvpacked_func( ...@@ -858,6 +868,9 @@ def flash_attn_varlen_kvpacked_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -938,6 +951,9 @@ def flash_attn_varlen_func( ...@@ -938,6 +951,9 @@ def flash_attn_varlen_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -981,8 +997,8 @@ def flash_attn_with_kvcache( ...@@ -981,8 +997,8 @@ def flash_attn_with_kvcache(
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
rotary_interleaved=True, rotary_interleaved=True,
num_splits=0,
alibi_slopes=None, alibi_slopes=None,
num_splits=0,
): ):
""" """
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
...@@ -1050,6 +1066,9 @@ def flash_attn_with_kvcache( ...@@ -1050,6 +1066,9 @@ def flash_attn_with_kvcache(
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style). (i.e. GPT-NeoX style).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
num_splits: int. If > 1, split the key/value into this many chunks along the sequence. num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits. to automatically determine the number of splits.
...@@ -1080,6 +1099,7 @@ def flash_attn_with_kvcache( ...@@ -1080,6 +1099,7 @@ def flash_attn_with_kvcache(
rotary_cos, rotary_cos,
rotary_sin, rotary_sin,
cache_batch_idx, cache_batch_idx,
alibi_slopes,
None, None,
softmax_scale, softmax_scale,
causal, causal,
...@@ -1087,6 +1107,5 @@ def flash_attn_with_kvcache( ...@@ -1087,6 +1107,5 @@ def flash_attn_with_kvcache(
window_size[1], window_size[1],
rotary_interleaved, rotary_interleaved,
num_splits, num_splits,
alibi_slopes,
) )
return out return out
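A hedged end-to-end usage sketch of the updated Python interface (shapes and slope values are illustrative; assumes a CUDA device and a build of flash-attn that includes this change):

import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, d = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Geometric ALiBi slope schedule, one fp32 slope per head; a (batch_size, nheads)
# tensor with per-sequence slopes is accepted as well.
slopes = torch.tensor([2 ** (-8 * (h + 1) / nheads) for h in range(nheads)],
                      device="cuda", dtype=torch.float32)

# Non-causal ALiBi is now supported: each score (i, j) gets a bias of
# -slope * |i + seqlen_k - seqlen_q - j|.
out = flash_attn_func(q, k, v, causal=False, alibi_slopes=slopes)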
import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (flash_attn_func, flash_attn_kvpacked_func,
flash_attn_qkvpacked_func, flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size
from flash_attn.flash_attn_triton import \
flash_attn_func as flash_attn_func_triton
from flash_attn.layers.rotary import apply_rotary_emb
MAX_HEADDIM_SM8x = 192
is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
def generate_alibi(max_seq_len, num_attention_heads, tp_world_size, tp_index, key_padding_mask=None, device="cuda"):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = (2 ** (-2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio ** i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
:n - closest_power_of_2]
slopes = torch.tensor(get_slopes(num_attention_heads)).to(device=device)
# Select the part of the tensor that corresponds to our tensor parallel index.
assert (num_attention_heads / tp_world_size).is_integer(), \
    "num_attention_heads must be divisible by tp_world_size"
nh_tp = num_attention_heads // tp_world_size
slopes = slopes[nh_tp * tp_index:nh_tp * (tp_index + 1)]
if (key_padding_mask is None):
arange_tensor = rearrange(torch.arange(max_seq_len), "sqk -> 1 sqk").to(device=device)
else:
arange_tensor = (key_padding_mask.cumsum(dim=-1, dtype=slopes.dtype) - 1) \
.masked_fill_(~key_padding_mask, torch.finfo(torch.float).min).to(device=device)
arange_tensor = rearrange(arange_tensor, 'b sqk -> b 1 1 sqk')
# (1, nheads, 1, seqlen_k) or (batch, nheads, 1, seqlen_k)
alibi_tensor = rearrange(slopes, 'nh -> 1 nh 1 1') * arange_tensor
return alibi_tensor, slopes
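# Quick self-check of generate_alibi (illustrative sizes, no key padding): for 8
# heads the slopes follow the geometric schedule 2^(-8 * (h + 1) / nheads), and
# without a key_padding_mask the bias grows linearly with the key position.
_alibi, _slopes = generate_alibi(max_seq_len=6, num_attention_heads=8,
                                 tp_world_size=1, tp_index=0, device="cpu")
assert _alibi.shape == (1, 8, 1, 6)                          # (1, nheads, 1, seqlen_k)
assert torch.allclose(_slopes[0], torch.tensor(2.0 ** -1))   # 2^(-8/8) for head 0
assert torch.allclose(_alibi[0, 0, 0], _slopes[0] * torch.arange(6.0))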
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", right_padding=True):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen,
device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
)
elif mode == "third":
lengths = torch.randint(
max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
if right_padding:
padding_mask = (
repeat(torch.arange(max_seqlen, device=device),
"s -> b s", b=batch_size) < lengths
)
else:
padding_mask = (
repeat(torch.arange(start=max_seqlen-1, end=-1, step=-1, device=device),
"s -> b s", b=batch_size) < lengths
)
return padding_mask
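# Usage sketch (illustrative): with right_padding=False the valid tokens sit at
# the end of each row, i.e. left padding, so the True entries of the mask are
# contiguous at the right edge (each row is non-decreasing).
_mask = generate_random_padding_mask(8, batch_size=4, device="cpu",
                                     mode="random", right_padding=False)
assert all(bool((row[:-1] <= row[1:]).all()) for row in _mask)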
def generate_qkv(
q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask)
def output_pad_fn(output_unpad): return pad_input(
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
)
max_seqlen_q = seqlen_q
def output_pad_fn(output_unpad): return rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(
k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert (query_padding_mask == key_padding_mask).all()
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
if query_padding_mask is not None:
def dqkv_pad_fn(dqkv_unpad): return pad_input(
dqkv_unpad, indices_q, batch_size, seqlen_q)
else:
def dqkv_pad_fn(dqkv_unpad): return rearrange(
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
max_seqlen_q,
qkv.detach().requires_grad_(),
output_pad_fn,
dqkv_pad_fn,
)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
def dkv_pad_fn(dkv_unpad): return pad_input(
dkv_unpad, indices_k, batch_size, seqlen_k)
else:
def dkv_pad_fn(dkv_unpad): return rearrange(
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
kv.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
)
else:
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
def dk_pad_fn(dk_unpad): return pad_input(
dk_unpad, indices_k, batch_size, seqlen_k)
else:
def dk_pad_fn(dk_unpad): return rearrange(
dk_unpad, "(b s) h d -> b s h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
k_unpad.detach().requires_grad_(),
v_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
)
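# Minimal usage sketch (ours): with no padding masks, generate_qkv just flattens the
# batch/sequence dims into the packed (total_tokens, nheads, d) layout that the varlen
# kernels expect, and cu_seqlens are plain multiples of the sequence length.
def _generate_qkv_sketch():
    b, sq, sk, h, d = 2, 4, 6, 3, 8
    q, k, v = torch.randn(b, sq, h, d), torch.randn(b, sk, h, d), torch.randn(b, sk, h, d)
    (
        q_unpad, k_unpad, v_unpad,
        cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k,
        q_pad, k_pad, v_pad,
        output_pad_fn, dq_pad_fn, dk_pad_fn,
    ) = generate_qkv(q, k, v)
    assert q_unpad.shape == (b * sq, h, d) and k_unpad.shape == (b * sk, h, d)
    assert cu_seqlens_q.tolist() == [0, sq, 2 * sq]
    assert output_pad_fn(q_unpad).shape == q.shape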
def construct_local_mask(
seqlen_q,
seqlen_k,
window_size=(-1, -1), # -1 means infinite window size
query_padding_mask=None,
key_padding_mask=None,
device=None,
):
row_idx = rearrange(torch.arange(
seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(
col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
col_idx < row_idx + sk - sq - window_size[0],
)
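# Worked example (ours): window_size=(-1, 0) reproduces the bottom-right-aligned causal
# mask, i.e. position (i, j) is masked out iff j > i + seqlen_k - seqlen_q.
def _local_mask_sketch():
    mask = construct_local_mask(seqlen_q=3, seqlen_k=5, window_size=(-1, 0), device="cpu")
    expected = torch.tensor(
        [
            [False, False, False, True, True],
            [False, False, False, False, True],
            [False, False, False, False, False],
        ]
    )
    assert torch.equal(mask, expected)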
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
reorder_ops=False,
bias=None
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if bias is not None:
bias = bias.to(scores.dtype)
scores += bias
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask,
"b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
attention = torch.softmax(scores, dim=-1)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(
torch.all(local_mask, dim=-1, keepdim=True), 0.0)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
output = torch.einsum(
"bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(
rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
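# Small usage sketch (ours): with fp32 inputs and no masking/dropout, attention_ref
# reduces to a plain softmax(Q K^T / sqrt(d)) V.
def _attention_ref_sketch():
    b, s, h, d = 1, 4, 2, 8
    q, k, v = torch.randn(b, s, h, d), torch.randn(b, s, h, d), torch.randn(b, s, h, d)
    out, attn = attention_ref(q, k, v)
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    expected = torch.einsum("bhts,bshd->bthd", torch.softmax(scores, dim=-1), v)
    assert torch.allclose(out, expected, atol=1e-6)
    assert torch.allclose(attn, torch.softmax(scores, dim=-1), atol=1e-6)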
def attention_kvpacked_ref(
q,
kv,
query_padding_mask=None,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
reorder_ops=False,
):
return attention_ref(
q,
kv[:, :, 0],
kv[:, :, 1],
query_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
window_size=window_size,
reorder_ops=reorder_ops,
)
def attention_qkvpacked_ref(
qkv,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
reorder_ops=False,
):
return attention_ref(
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
key_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
window_size=window_size,
reorder_ops=reorder_ops,
)
def generate_sparsity_mask(seqlen, sparsity=0.3):
repeats = seqlen // 16 // 2
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
nrow, ncol = seqlen // 16, seqlen // 256
mask = torch.rand(nrow, ncol, device="cuda") < sparsity
return mask
def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, head_dim)
blockmask: (seqlen / 16, seqlen / 256)
attn_mask: (batch_size, seqlen)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen, seqlen)
Output:
output: (batch_size, seqlen, nheads, head_dim)
attention: softmax after dropout
"""
q, k, v = qkv.float().unbind(dim=2)
d = qkv.shape[-1]
seqlen = qkv.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
blockmask = blockmask[:seqlen, :seqlen]
scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
attention = torch.softmax(scores, dim=-1)
attention = attention.masked_fill(
rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
attention = attention.masked_fill_(
rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
attention_drop = attention.masked_fill(
~dropout_mask, 0.0) / (1 - dropout_p)
output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
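# Layout sketch (ours): the blockmask is defined on (seqlen/16, seqlen/256) tiles and the
# repeat(...) above expands each entry to a dense 16 x 256 block before it masks scores.
def _blockmask_expand_sketch():
    blockmask = torch.tensor([[True, False], [False, True]])
    dense = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
    assert dense.shape == (32, 512)
    assert dense[:16, :256].all() and not dense[:16, 256:].any()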
def convert_flash_attn_S_to_softmax(
S,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
head_dim,
is_dropout,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""FlashAttention stores the S matrix in a different way.
Arguments:
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
query_padding_mask: (batch_size, seqlen_q_rounded)
key_padding_mask: (batch_size, seqlen_k_rounded)
"""
if causal:
window_size = (window_size[0], 0)
seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
warps_n = 4
blocksize_m, blocksize_n = _get_block_size(
S.device, head_dim, is_dropout, causal)
nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
mmas_n = (blocksize_n + 16 - 1) // 16
S_flat = rearrange(
S,
"b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
blocksize_m=blocksize_m,
blocksize_n=blocksize_n,
)
S_converted = rearrange(
S_flat,
"b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
mmas_n=mmas_n,
warps_n=warps_n,
eight=8,
c0=2,
c1=2,
c2=2,
four=4,
)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
S.device,
)
local_mask = F.pad(
local_mask,
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
value=True,
)
S_converted.masked_fill_(local_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten.
seqlen_q_og = (
query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
)
if query_padding_mask is not None:
query_padding_mask = F.pad(
query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
S_converted = S_converted.masked_fill(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
if key_padding_mask is not None:
key_padding_mask = F.pad(
key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
S_converted = S_converted.masked_fill(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
return S_converted[:, :, :seqlen_q, :seqlen_k]
def normalize_flash_attn_S(
attn_unnorm,
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
is_dropout=False,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k, v: (batch_size, seqlen_k, nheads, head_dim)
key_padding_mask: (batch_size, seqlen_k)
Output:
softmax_lse: (batch_size, nheads, seqlen_q)
softmax_max: (batch_size, nheads, seqlen_q)
"""
if causal:
window_size = (window_size[0], 0)
q, k, v = q.float(), k.float(), v.float()
_, seqlen_q, _, head_dim = q.shape
seqlen_k = k.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask,
"b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
_, block_size_n = _get_block_size(
scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1)
for s in scores_block], dim=-1)
lse = torch.logsumexp(lse_block, dim=-1)
# lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
# so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
lse[lse == float("-inf")] = float("inf")
scores_max_block = torch.stack(
[torch.amax(s, dim=-1) for s in scores_block], dim=-1)
cummax_block = torch.cummax(
scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
attn_norm = torch.cat(
[
a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
for a, m in zip(attn_unnorm_block, cummax_block)
],
dim=-1,
)
if query_padding_mask is not None:
attn_norm.masked_fill_(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
return attn_norm.to(dtype=attn_unnorm.dtype)
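# Background sketch (ours) for the renormalization above: if a kernel stores the
# unnormalized probabilities exp(scores - rowmax), multiplying by exp(rowmax - lse)
# with lse = logsumexp(scores, dim=-1) recovers the true softmax.
def _lse_renorm_sketch():
    scores = torch.randn(2, 3, 4, 6)  # (batch, nheads, seqlen_q, seqlen_k)
    rowmax = scores.amax(dim=-1, keepdim=True)
    unnorm = torch.exp(scores - rowmax)
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)
    assert torch.allclose(unnorm * torch.exp(rowmax - lse), torch.softmax(scores, dim=-1), atol=1e-6)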
def get_dropout_fraction(
dropout_mask,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
"""
if causal:
window_size = (window_size[0], 0)
batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
dropped = ~dropout_mask
valid = torch.ones_like(dropout_mask)
if query_padding_mask is not None:
dropped.masked_fill_(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
valid.masked_fill_(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
if key_padding_mask is not None:
dropped.masked_fill_(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
valid.masked_fill_(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
dropout_mask.device,
)
dropped.masked_fill_(local_mask, False)
valid.masked_fill_(local_mask, False)
dropped_total = dropped.sum()
return dropped_total / valid.sum()
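# Sanity sketch (ours): on a random keep-mask drawn with keep probability 1 - p, the
# measured dropout fraction should be close to p.
def _dropout_fraction_sketch():
    torch.random.manual_seed(0)
    p = 0.17
    keep_mask = torch.rand(4, 8, 128, 128) > p  # True = keep, False = drop
    frac = get_dropout_fraction(keep_mask)
    assert abs(frac.item() - p) < 0.02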
@pytest.mark.parametrize(
"dtype", [torch.float16]
)
@pytest.mark.parametrize(
"b_sq",
[
(32, 512),
(16, 1024),
(8, 2048),
(4, 4096),
(2, 8192),
(1, 16384)
]
)
@pytest.mark.parametrize(
"nh_hd",
[
(32, 64),
(16, 128),
(40, 128) # non power of 2 nh
]
)
@pytest.mark.parametrize(
"tp_world_size", [1, 2, 4]
)
def test_flash_attn_func(b_sq, nh_hd, tp_world_size, dtype):
b, sq = b_sq
nh, hd = nh_hd
nh_tp = nh // tp_world_size
q, k, v = [torch.randn(b, sq, nh_tp, hd, device="cuda",
dtype=dtype, requires_grad=True) for _ in range(3)]
dout = torch.rand_like(q)
for tp_index in range(tp_world_size):
alibi, alibi_slopes = generate_alibi(
max_seq_len=sq,
num_attention_heads=nh,
tp_world_size=tp_world_size,
tp_index=tp_index,
key_padding_mask=None,
device="cuda"
)
triton_out = flash_attn_func_triton(
q, k, v, alibi, True, hd**(-0.5))
triton_out.backward(dout)
triton_dq, q.grad = q.grad.clone(), None
triton_dk, k.grad = k.grad.clone(), None
triton_dv, v.grad = v.grad.clone(), None
flash_out = flash_attn_func(q, k, v, causal=True, alibi_slopes=repeat(alibi_slopes, "nh -> b nh", b=b))
flash_out.backward(dout)
flash_dq, q.grad = q.grad.clone(), None
flash_dk, k.grad = k.grad.clone(), None
flash_dv, v.grad = v.grad.clone(), None
assert torch.allclose(flash_out, triton_out, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dq, triton_dq, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dk, triton_dk, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dv, triton_dv, atol=1e-2, rtol=0.)
@pytest.mark.parametrize(
"dtype", [torch.float16]
)
@pytest.mark.parametrize(
"right_padding", [True, False]
)
@pytest.mark.parametrize(
"b_sq",
[
(32, 512),
(16, 1024),
(8, 2048),
(4, 4096),
(2, 8192),
(1, 16384)
]
)
@pytest.mark.parametrize(
"nh_hd",
[
(32, 64),
(16, 128),
(40, 128) # non power of 2 nh
]
)
@pytest.mark.parametrize(
"tp_world_size", [1, 2, 4]
)
def test_flash_attn_varlen_func(b_sq, nh_hd, tp_world_size, right_padding, dtype):
b, sqk = b_sq
nh, hd = nh_hd
nh_tp = nh // tp_world_size
# flash_attn_func_triton() and flash-attention v2 (>= v2.1) use different causal-mask conventions,
# so the results only match when seqlen_q == 1 and causal=False is passed to the Triton version.
# https://github.com/huggingface/text-generation-inference/blob/v1.1.1/server/text_generation_server/models/custom_modeling/mpt_modeling.py#L53-L63
q = torch.randn(b, 1, nh_tp, hd, device="cuda", dtype=dtype, requires_grad=True)
k, v = [torch.randn(b, sqk, nh_tp, hd, device="cuda",
dtype=dtype, requires_grad=True) for _ in range(2)]
dout = torch.rand_like(q)
padding_mask = generate_random_padding_mask(sqk, b, "cuda", "random", right_padding)
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, None, padding_mask, kvpacked=False)
for tp_index in range(tp_world_size):
alibi, alibi_slopes = generate_alibi(
max_seq_len=sqk,
num_attention_heads=nh,
tp_world_size=tp_world_size,
tp_index=tp_index,
key_padding_mask=padding_mask,
device="cuda"
)
triton_out = flash_attn_func_triton(
q, k, v, alibi, False, hd**(-0.5))
triton_out.backward(dout)
triton_dq, q.grad = q.grad.clone(), None
triton_dk, k.grad = k.grad.clone(), None
triton_dv, v.grad = v.grad.clone(), None
flash_out_unpad = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
causal=True,
alibi_slopes=repeat(alibi_slopes, "nh -> b nh", b=b)
)
flash_out = output_pad_fn(flash_out_unpad)
flash_out.backward(dout)
flash_dq_unpad, q_unpad.grad = q_unpad.grad.clone(), None
flash_dk_unpad, k_unpad.grad = k_unpad.grad.clone(), None
flash_dv_unpad, v_unpad.grad = v_unpad.grad.clone(), None
flash_dq = dq_pad_fn(flash_dq_unpad)
flash_dk = dk_pad_fn(flash_dk_unpad)
flash_dv = dk_pad_fn(flash_dv_unpad)  # k and v share the same padding, so dk_pad_fn works for dv too
assert torch.allclose(flash_out, triton_out, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dq, triton_dq, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dk, triton_dk, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dv, triton_dv, atol=1e-2, rtol=0.)
@pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [True])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 128),
(1, 339),
(3, 1024),
(64, 800),
(64, 256),
(3, 799),
(64, 2048),
(16, 20000),
(1, 128 * 1024),
(16, 128 * 1024),
(128, 128),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
seqlen_q,
seqlen_k,
d,
has_batch_idx,
rotary_fraction,
rotary_interleaved,
seqlen_new_eq_seqlen_q,
causal,
local,
new_kv,
mha_type,
num_splits,
dtype,
alibi,
):
if seqlen_q > seqlen_k and new_kv:
pytest.skip()
if not new_kv and rotary_fraction > 0.0:
pytest.skip()
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 2
batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
nheads = 8
# rotary_dim must be a multiple of 16, and must be <= d
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 4)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads,
d, device=device, dtype=dtype)
seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(
1, seqlen_q + 1, (1,)).item()
if new_kv:
k = torch.randn(batch_size, seqlen_new, nheads_k,
d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_new, nheads_k,
d, device=device, dtype=dtype)
else:
k, v = None, None
k_cache = torch.randn(batch_size_cache, seqlen_k,
nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(batch_size_cache, seqlen_k,
nheads_k, d, device=device, dtype=dtype)
cache_seqlens = torch.randint(
0,
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(seqlen_k - (seqlen_q if (causal or local)
and rotary_dim > 1 else seqlen_new) + 1)
if new_kv
else (seqlen_k + 1),
(batch_size,),
dtype=torch.int32,
device=device,
)
if has_batch_idx:
cache_batch_idx = torch.randperm(
batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
else:
cache_batch_idx = None
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
if rotary_dim > 0:
angle = torch.rand(seqlen_k, rotary_dim // 2,
device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if causal or local:
q_ro = apply_rotary_emb(
q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
q_ro = rearrange(
apply_rotary_emb(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
seqlen_offsets=cache_seqlens,
interleaved=rotary_interleaved,
),
"b 1 (s h) d -> b s h d",
s=seqlen_q,
)
# q_ro = q
k_ro = apply_rotary_emb(
k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
cos, sin = None, None
q_ro, k_ro = q, k
# k_cache[:, 64:] = -1
k_cache_ref = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
v_cache_ref = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
if new_kv:
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
)
k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
k_cache_rep = repeat(
k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_cache_rep = repeat(
v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
if alibi:
seqlen_alibi = k_cache_rep.shape[1]
alibi_tensor, alibi_slopes = generate_alibi(
max_seq_len=seqlen_alibi,
num_attention_heads=nheads,
tp_world_size=1,
tp_index=0,
key_padding_mask=None,
device="cuda"
)
# alibi_tensor = alibi_tensor.expand(batch_size, -1, seqlen_q, -1)
alibi_slopes = repeat(alibi_slopes, "nh -> b nh", b=batch_size)
if alibi_tensor.abs().max().item() >= torch.finfo(dtype).max:
pytest.skip()
else:
alibi_tensor, alibi_slopes = None, None
out = flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k,
v,
cos,
sin,
cache_seqlens,
cache_batch_idx,
causal=causal,
window_size=window_size,
rotary_interleaved=rotary_interleaved,
num_splits=num_splits,
alibi_slopes=alibi_slopes
)
# out = flash_attn_with_kvcache(
# q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
# )
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# probs = torch.softmax(qk, dim=-1)
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
out_ref, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
0.0,
None,
causal=causal,
window_size=window_size,
bias=alibi_tensor
)
out_pt, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
0.0,
None,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
bias=alibi_tensor
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most a small multiple (3x)
# of the numerical error of a Pytorch implementation.
if new_kv:
k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
assert torch.allclose(k_cache_select, k_cache_ref,
rtol=1e-3, atol=1e-3)
assert torch.equal(v_cache_select, v_cache_ref)
assert (out - out_ref).abs().max().item() <= 3 * \
(out_pt - out_ref).abs().max().item() + 1e-5
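# Shape sketch (ours): the update_mask built inside the test marks, per batch element, the
# cache slots [cache_seqlens, cache_seqlens + seqlen_new) that the new keys/values fill.
def _kvcache_update_mask_sketch():
    seqlen_k, seqlen_new = 8, 3
    cache_seqlens = torch.tensor([[2], [5]])  # (batch, 1)
    arange = torch.arange(seqlen_k).unsqueeze(0)  # (1, seqlen_k)
    update_mask = (cache_seqlens <= arange) & (arange < cache_seqlens + seqlen_new)
    assert update_mask.sum(-1).tolist() == [3, 3]
    assert update_mask[0].nonzero().flatten().tolist() == [2, 3, 4]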
...@@ -26,6 +26,31 @@ is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0) ...@@ -26,6 +26,31 @@ is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0) is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
def attn_bias_from_alibi_slopes(
slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False
):
batch, nheads = slopes.shape
device = slopes.device
slopes = rearrange(slopes, "b h -> b h 1 1")
if causal:
return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
else:
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
relative_pos = torch.abs(row_idx + sk - sq - col_idx)
return -slopes * relative_pos.to(dtype=slopes.dtype)
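# Consistency sketch (ours, not part of this commit): for the last query position the
# non-causal |i - j| bias above coincides with the causal bias, since that row only
# attends backwards.
def _alibi_bias_consistency_sketch():
    slopes = torch.rand(2, 3) * 0.3  # (batch, nheads)
    causal_bias = attn_bias_from_alibi_slopes(slopes, 5, 7, causal=True)
    full_bias = attn_bias_from_alibi_slopes(slopes, 5, 7, causal=False)
    assert torch.allclose(causal_bias[:, :, 0], full_bias[:, :, -1])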
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
assert mode in ["full", "random", "third"] assert mode in ["full", "random", "third"]
if mode == "full": if mode == "full":
...@@ -186,6 +211,7 @@ def attention_ref( ...@@ -186,6 +211,7 @@ def attention_ref(
v, v,
query_padding_mask=None, query_padding_mask=None,
key_padding_mask=None, key_padding_mask=None,
attn_bias=None,
dropout_p=0.0, dropout_p=0.0,
dropout_mask=None, dropout_mask=None,
causal=False, causal=False,
...@@ -200,6 +226,7 @@ def attention_ref( ...@@ -200,6 +226,7 @@ def attention_ref(
v: (batch_size, seqlen_k, nheads_k, head_dim) v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q) query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k) key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking causal: whether to apply causal masking
...@@ -238,7 +265,9 @@ def attention_ref( ...@@ -238,7 +265,9 @@ def attention_ref(
q.device, q.device,
) )
scores.masked_fill_(local_mask, float("-inf")) scores.masked_fill_(local_mask, float("-inf"))
attention = torch.softmax(scores, dim=-1) if attn_bias is not None:
scores = scores + attn_bias
attention = torch.softmax(scores, dim=-1).to(v.dtype)
# Some rows might be completely masked out so we fill them with zero instead of NaN # Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0: if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
...@@ -264,6 +293,7 @@ def attention_kvpacked_ref( ...@@ -264,6 +293,7 @@ def attention_kvpacked_ref(
kv, kv,
query_padding_mask=None, query_padding_mask=None,
key_padding_mask=None, key_padding_mask=None,
attn_bias=None,
dropout_p=0.0, dropout_p=0.0,
dropout_mask=None, dropout_mask=None,
causal=False, causal=False,
...@@ -277,6 +307,7 @@ def attention_kvpacked_ref( ...@@ -277,6 +307,7 @@ def attention_kvpacked_ref(
kv[:, :, 1], kv[:, :, 1],
query_padding_mask, query_padding_mask,
key_padding_mask, key_padding_mask,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
upcast=upcast, upcast=upcast,
...@@ -289,6 +320,7 @@ def attention_kvpacked_ref( ...@@ -289,6 +320,7 @@ def attention_kvpacked_ref(
def attention_qkvpacked_ref( def attention_qkvpacked_ref(
qkv, qkv,
key_padding_mask=None, key_padding_mask=None,
attn_bias=None,
dropout_p=0.0, dropout_p=0.0,
dropout_mask=None, dropout_mask=None,
causal=False, causal=False,
...@@ -302,6 +334,7 @@ def attention_qkvpacked_ref( ...@@ -302,6 +334,7 @@ def attention_qkvpacked_ref(
qkv[:, :, 2], qkv[:, :, 2],
key_padding_mask, key_padding_mask,
key_padding_mask, key_padding_mask,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
upcast=upcast, upcast=upcast,
...@@ -436,6 +469,7 @@ def normalize_flash_attn_S( ...@@ -436,6 +469,7 @@ def normalize_flash_attn_S(
v, v,
query_padding_mask=None, query_padding_mask=None,
key_padding_mask=None, key_padding_mask=None,
attn_bias=None,
is_dropout=False, is_dropout=False,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite window size window_size=(-1, -1), # -1 means infinite window size
...@@ -445,6 +479,7 @@ def normalize_flash_attn_S( ...@@ -445,6 +479,7 @@ def normalize_flash_attn_S(
q: (batch_size, seqlen_q, nheads, head_dim) q: (batch_size, seqlen_q, nheads, head_dim)
k, v: (batch_size, seqlen_k, nheads, head_dim) k, v: (batch_size, seqlen_k, nheads, head_dim)
key_padding_mask: (batch_size, seqlen_k) key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
Output: Output:
softmax_lse: (batch_size, nheads, seqlen_q) softmax_lse: (batch_size, nheads, seqlen_q)
softmax_max: (batch_size, nheads, seqlen_q) softmax_max: (batch_size, nheads, seqlen_q)
...@@ -467,6 +502,8 @@ def normalize_flash_attn_S( ...@@ -467,6 +502,8 @@ def normalize_flash_attn_S(
q.device, q.device,
) )
scores.masked_fill_(local_mask, float("-inf")) scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias.to(dtype=scores.dtype)
_, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal) _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1) scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
...@@ -529,6 +566,8 @@ def get_dropout_fraction( ...@@ -529,6 +566,8 @@ def get_dropout_fraction(
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16]) # @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True]) # @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
...@@ -538,24 +577,34 @@ def get_dropout_fraction( ...@@ -538,24 +577,34 @@ def get_dropout_fraction(
# @pytest.mark.parametrize('d', [32, 64, 96, 128]) # @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize("d", [64]) # @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [128]) # @pytest.mark.parametrize("seqlen", [97])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0]) # @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype): def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM pytest.skip() # Reference implementation OOM
device = "cuda" device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 13 batch_size = 8
nheads = 9 nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,)) window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
qkv = torch.randn( qkv = torch.randn(
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
) )
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
else:
alibi_slopes, attn_bias = None, None
out, lse, S_dmask = flash_attn_qkvpacked_func( out, lse, S_dmask = flash_attn_qkvpacked_func(
qkv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True qkv,
dropout_p,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_attn_probs=True,
) )
if dropout_p > 0.0: if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax( S_dmask_converted = convert_flash_attn_S_to_softmax(
...@@ -578,6 +627,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype): ...@@ -578,6 +627,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
qkv[:, :, 2], qkv[:, :, 2],
None, None,
None, None,
attn_bias,
dropout_p > 0.0, dropout_p > 0.0,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
...@@ -590,11 +640,12 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype): ...@@ -590,11 +640,12 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
dropout_mask = None dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref( out_ref, attn_ref = attention_qkvpacked_ref(
qkv, None, dropout_p, dropout_mask, causal=causal, window_size=window_size qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
) )
out_pt, attn_pt = attention_qkvpacked_ref( out_pt, attn_pt = attention_qkvpacked_ref(
qkv, qkv,
None, None,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
...@@ -651,6 +702,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype): ...@@ -651,6 +702,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
if dropout_p > 0.0: if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
...@@ -659,18 +712,20 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype): ...@@ -659,18 +712,20 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, dtype):
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True]) # @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False]) # @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128]) # @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0]) # @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype): def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM pytest.skip() # Reference implementation OOM
device = "cuda" device = "cuda"
...@@ -685,6 +740,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype) ...@@ -685,6 +740,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random") key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(
alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
)
else:
alibi_slopes, attn_bias = None, None
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
*qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
...@@ -697,6 +759,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype) ...@@ -697,6 +759,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
dropout_p, dropout_p,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_attn_probs=True, return_attn_probs=True,
) )
out = output_pad_fn(out_unpad) out = output_pad_fn(out_unpad)
...@@ -721,6 +784,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype) ...@@ -721,6 +784,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
qkv[:, :, 2], qkv[:, :, 2],
key_padding_mask, key_padding_mask,
key_padding_mask, key_padding_mask,
attn_bias,
dropout_p > 0.0, dropout_p > 0.0,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
...@@ -733,11 +797,18 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype) ...@@ -733,11 +797,18 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
dropout_mask = None dropout_mask = None
out_ref, attn_ref = attention_qkvpacked_ref( out_ref, attn_ref = attention_qkvpacked_ref(
qkv, key_padding_mask, dropout_p, dropout_mask, causal=causal, window_size=window_size qkv,
key_padding_mask,
attn_bias,
dropout_p,
dropout_mask,
causal=causal,
window_size=window_size,
) )
out_pt, attn_pt = attention_qkvpacked_ref( out_pt, attn_pt = attention_qkvpacked_ref(
qkv, qkv,
key_padding_mask, key_padding_mask,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
...@@ -774,6 +845,8 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype) ...@@ -774,6 +845,8 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
if dropout_p > 0.0: if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
...@@ -786,11 +859,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype) ...@@ -786,11 +859,13 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
# @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"]) # @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True]) # @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
...@@ -815,7 +890,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype) ...@@ -815,7 +890,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, dtype)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.17]) # @pytest.mark.parametrize("dropout_p", [0.17])
def test_flash_attn_output( def test_flash_attn_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
): ):
if ( if (
max(seqlen_q, seqlen_k) >= 2048 max(seqlen_q, seqlen_k) >= 2048
...@@ -825,7 +900,7 @@ def test_flash_attn_output( ...@@ -825,7 +900,7 @@ def test_flash_attn_output(
device = "cuda" device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 13 batch_size = 8
nheads = 9 nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0 assert nheads % nheads_k == 0
...@@ -842,14 +917,32 @@ def test_flash_attn_output( ...@@ -842,14 +917,32 @@ def test_flash_attn_output(
v = torch.randn( v = torch.randn(
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
) )
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
else:
alibi_slopes, attn_bias = None, None
if kvpacked: if kvpacked:
out, lse, S_dmask = flash_attn_kvpacked_func( out, lse, S_dmask = flash_attn_kvpacked_func(
q, kv, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True q,
kv,
dropout_p,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_attn_probs=True,
) )
else: else:
out, lse, S_dmask = flash_attn_func( out, lse, S_dmask = flash_attn_func(
q, k, v, dropout_p, causal=causal, window_size=window_size, return_attn_probs=True q,
k,
v,
dropout_p,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_attn_probs=True,
) )
if dropout_p > 0.0: if dropout_p > 0.0:
S_dmask_converted = convert_flash_attn_S_to_softmax( S_dmask_converted = convert_flash_attn_S_to_softmax(
...@@ -878,6 +971,7 @@ def test_flash_attn_output( ...@@ -878,6 +971,7 @@ def test_flash_attn_output(
v_rep, v_rep,
None, None,
None, None,
attn_bias,
dropout_p > 0.0, dropout_p > 0.0,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
...@@ -895,6 +989,7 @@ def test_flash_attn_output( ...@@ -895,6 +989,7 @@ def test_flash_attn_output(
kv, kv,
None, None,
None, None,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
...@@ -905,6 +1000,7 @@ def test_flash_attn_output( ...@@ -905,6 +1000,7 @@ def test_flash_attn_output(
kv, kv,
None, None,
None, None,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
...@@ -919,6 +1015,7 @@ def test_flash_attn_output( ...@@ -919,6 +1015,7 @@ def test_flash_attn_output(
v, v,
None, None,
None, None,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
...@@ -930,6 +1027,7 @@ def test_flash_attn_output( ...@@ -930,6 +1027,7 @@ def test_flash_attn_output(
v, v,
None, None,
None, None,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
...@@ -1000,6 +1098,8 @@ def test_flash_attn_output( ...@@ -1000,6 +1098,8 @@ def test_flash_attn_output(
if dropout_p > 0.0: if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
...@@ -1014,11 +1114,13 @@ def test_flash_attn_output( ...@@ -1014,11 +1114,13 @@ def test_flash_attn_output(
# @pytest.mark.parametrize('dtype', [torch.float16]) # @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"]) # @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True]) # @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True]) # @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -1041,7 +1143,7 @@ def test_flash_attn_output( ...@@ -1041,7 +1143,7 @@ def test_flash_attn_output(
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0]) # @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output( def test_flash_attn_varlen_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, mha_type, dtype, kvpacked seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, mha_type, dtype, kvpacked
): ):
if ( if (
max(seqlen_q, seqlen_k) >= 2048 max(seqlen_q, seqlen_k) >= 2048
...@@ -1051,7 +1153,7 @@ def test_flash_attn_varlen_output( ...@@ -1051,7 +1153,7 @@ def test_flash_attn_varlen_output(
device = "cuda" device = "cuda"
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 13 batch_size = 8
nheads = 9 nheads = 9
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
assert nheads % nheads_k == 0 assert nheads % nheads_k == 0
...@@ -1072,6 +1174,13 @@ def test_flash_attn_varlen_output( ...@@ -1072,6 +1174,13 @@ def test_flash_attn_varlen_output(
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
# key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full') # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(
alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
)
else:
alibi_slopes, attn_bias = None, None
if kvpacked: if kvpacked:
( (
...@@ -1095,9 +1204,10 @@ def test_flash_attn_varlen_output( ...@@ -1095,9 +1204,10 @@ def test_flash_attn_varlen_output(
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
dropout_p, dropout_p,
return_attn_probs=True,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_attn_probs=True,
) )
else: else:
( (
...@@ -1124,9 +1234,10 @@ def test_flash_attn_varlen_output( ...@@ -1124,9 +1234,10 @@ def test_flash_attn_varlen_output(
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
dropout_p, dropout_p,
return_attn_probs=True,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_attn_probs=True,
) )
out = output_pad_fn(out_unpad) out = output_pad_fn(out_unpad)
if dropout_p > 0.0: if dropout_p > 0.0:
...@@ -1156,6 +1267,7 @@ def test_flash_attn_varlen_output( ...@@ -1156,6 +1267,7 @@ def test_flash_attn_varlen_output(
v_rep, v_rep,
query_padding_mask, query_padding_mask,
key_padding_mask, key_padding_mask,
attn_bias,
dropout_p > 0.0, dropout_p > 0.0,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
...@@ -1177,6 +1289,7 @@ def test_flash_attn_varlen_output( ...@@ -1177,6 +1289,7 @@ def test_flash_attn_varlen_output(
kv, kv,
query_padding_mask, query_padding_mask,
key_padding_mask, key_padding_mask,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
...@@ -1187,6 +1300,7 @@ def test_flash_attn_varlen_output( ...@@ -1187,6 +1300,7 @@ def test_flash_attn_varlen_output(
kv, kv,
query_padding_mask, query_padding_mask,
key_padding_mask, key_padding_mask,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
...@@ -1201,6 +1315,7 @@ def test_flash_attn_varlen_output( ...@@ -1201,6 +1315,7 @@ def test_flash_attn_varlen_output(
v, v,
query_padding_mask, query_padding_mask,
key_padding_mask, key_padding_mask,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
...@@ -1212,6 +1327,7 @@ def test_flash_attn_varlen_output( ...@@ -1212,6 +1327,7 @@ def test_flash_attn_varlen_output(
v, v,
query_padding_mask, query_padding_mask,
key_padding_mask, key_padding_mask,
attn_bias,
dropout_p, dropout_p,
dropout_mask, dropout_mask,
causal=causal, causal=causal,
...@@ -1284,12 +1400,14 @@ def test_flash_attn_varlen_output( ...@@ -1284,12 +1400,14 @@ def test_flash_attn_varlen_output(
if dropout_p > 0.0: if dropout_p > 0.0:
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
...@@ -1332,7 +1450,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): ...@@ -1332,7 +1450,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
causal = True causal = True
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 13 batch_size = 8
nheads = 9 nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
...@@ -1340,7 +1458,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): ...@@ -1340,7 +1458,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size) out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
out_ref, attn_ref = attention_ref( out_ref, attn_ref = attention_ref(
q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
) )
out_pt, attn_pt = attention_ref( out_pt, attn_pt = attention_ref(
q, q,
...@@ -1348,6 +1466,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): ...@@ -1348,6 +1466,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
v, v,
None, None,
None, None,
None,
0.0, 0.0,
None, None,
causal=causal, causal=causal,
...@@ -1442,7 +1561,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp ...@@ -1442,7 +1561,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
causal = True causal = True
# set seed # set seed
torch.random.manual_seed(0) torch.random.manual_seed(0)
batch_size = 13 batch_size = 8
nheads = 9 nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
...@@ -1484,6 +1603,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp ...@@ -1484,6 +1603,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
v, v,
query_padding_mask, query_padding_mask,
key_padding_mask, key_padding_mask,
None,
0.0, 0.0,
None, None,
causal=causal, causal=causal,
...@@ -1495,6 +1615,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp ...@@ -1495,6 +1615,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
v, v,
query_padding_mask, query_padding_mask,
key_padding_mask, key_padding_mask,
None,
0.0, 0.0,
None, None,
causal=causal, causal=causal,
...@@ -1554,8 +1675,10 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp ...@@ -1554,8 +1675,10 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16]) # @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True]) # @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
...@@ -1581,7 +1704,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp ...@@ -1581,7 +1704,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
], ],
) )
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, dtype):
if swap_sq_sk: if swap_sq_sk:
seqlen_q, seqlen_k = seqlen_k, seqlen_q seqlen_q, seqlen_k = seqlen_k, seqlen_q
device = "cuda" device = "cuda"
...@@ -1593,11 +1716,23 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt ...@@ -1593,11 +1716,23 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
else:
alibi_slopes, attn_bias = None, None
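For the reference path, the test materializes a dense bias from the per-head slopes (via attn_bias_from_alibi_slopes) so that attention_ref sees the same penalty the kernel derives on the fly from alibi_slopes alone. Roughly, causal ALiBi adds slope * (j - i) to the score of query i attending to key j, and a non-causal variant penalizes absolute relative distance. A hedged sketch of such a bias builder (an assumption for illustration; the repository's helper may differ in alignment and padding details):

```python
# Hedged sketch of an ALiBi reference bias; not the repository's
# attn_bias_from_alibi_slopes, which may differ in details.
import torch
from einops import rearrange

def alibi_bias_sketch(slopes, seqlen_q, seqlen_k, causal=False):
    """slopes: (batch, nheads) fp32 -> bias: (batch, nheads, seqlen_q, seqlen_k)."""
    row = rearrange(torch.arange(seqlen_q), "s -> s 1")
    col = torch.arange(seqlen_k)
    # Bottom-right-aligned relative position of key j w.r.t. query i.
    rel = col - row + (seqlen_q - seqlen_k)
    rel = rel.clamp(max=0) if causal else -rel.abs()  # causal: only the past matters
    return rearrange(slopes, "b h -> b h 1 1") * rel.float()

bias = alibi_bias_sketch(torch.rand(2, 9) * 0.3, seqlen_q=128, seqlen_k=128, causal=True)
assert bias.shape == (2, 9, 128, 128)
```

Because the kernel only ever receives the (batch, nheads) slopes while the reference gets the dense bias, the comparison exercises the kernel's internal ALiBi computation end to end.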
out, lse, _ = flash_attn_func( out, lse, _ = flash_attn_func(
q, k, v, 0.0, causal=causal, window_size=window_size, return_attn_probs=True q,
k,
v,
0.0,
causal=causal,
window_size=window_size,
alibi_slopes=alibi_slopes,
return_attn_probs=True,
) )
out_ref, attn_ref = attention_ref( out_ref, attn_ref = attention_ref(
q, k, v, None, None, 0.0, None, causal=causal, window_size=window_size q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size
) )
out_pt, attn_pt = attention_ref( out_pt, attn_pt = attention_ref(
q, q,
...@@ -1605,6 +1740,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt ...@@ -1605,6 +1740,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
v, v,
None, None,
None, None,
attn_bias,
0.0, 0.0,
None, None,
causal=causal, causal=causal,
...@@ -1653,24 +1789,27 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt ...@@ -1653,24 +1789,27 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
# of a Pytorch implementation. # of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
mult = 2 if not alibi else 8
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 2e-4 assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4 assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4 assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0]) @pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0]) # @pytest.mark.parametrize("num_splits", [1])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"]) # @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True]) @pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [True]) # @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True]) @pytest.mark.parametrize("rotary_interleaved", [False, True])
...@@ -1678,7 +1817,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt ...@@ -1678,7 +1817,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("has_batch_idx", [False, True]) @pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [True]) # @pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
...@@ -1711,6 +1850,7 @@ def test_flash_attn_kvcache( ...@@ -1711,6 +1850,7 @@ def test_flash_attn_kvcache(
seqlen_new_eq_seqlen_q, seqlen_new_eq_seqlen_q,
causal, causal,
local, local,
alibi,
new_kv, new_kv,
mha_type, mha_type,
num_splits, num_splits,
...@@ -1750,10 +1890,22 @@ def test_flash_attn_kvcache( ...@@ -1750,10 +1890,22 @@ def test_flash_attn_kvcache(
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
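The broadcast above builds a per-batch boolean mask over the key cache: slot j of batch element b is valid when j < cache_seqlens[b] (plus seqlen_new when new keys are appended). A tiny illustration with made-up lengths:

```python
# Tiny illustration of the broadcasted key-padding mask (lengths are made up).
import torch
from einops import rearrange

cache_seqlens = torch.tensor([3, 5], dtype=torch.int32)  # valid keys per batch element
arange = rearrange(torch.arange(8), "s -> 1 s")           # (1, seqlen_k)
mask = arange < rearrange(cache_seqlens, "b -> b 1")      # (batch, seqlen_k), bool
# mask[0]: three leading Trues; mask[1]: five leading Trues; the rest are padding.
```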
if has_batch_idx: if has_batch_idx:
cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size] cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
:batch_size
]
else: else:
cache_batch_idx = None cache_batch_idx = None
if alibi:
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
attn_bias = attn_bias_from_alibi_slopes(
alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal
)
else:
alibi_slopes, attn_bias = None, None
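Here the reference bias is built with the key padding mask as well, presumably so that the non-causal ALiBi distance is anchored to each sequence's last valid key rather than to seqlen_k, leaving padded slots without influence on the penalty. A hedged, padding-aware variant of the earlier sketch (again an assumption about the helper, not its definition):

```python
# Hedged sketch: non-causal ALiBi bias anchored to each batch's valid key length.
import torch
from einops import rearrange

def alibi_bias_padded_sketch(slopes, seqlen_q, key_padding_mask):
    """slopes: (b, h); key_padding_mask: (b, seqlen_k) bool -> bias: (b, h, sq, sk)."""
    seqlen_k = key_padding_mask.shape[-1]
    sk = rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")  # valid keys per batch
    row = rearrange(torch.arange(seqlen_q), "q -> 1 1 q 1")
    col = rearrange(torch.arange(seqlen_k), "k -> 1 1 1 k")
    rel = (col - row + (seqlen_q - sk)).abs()                 # distance to aligned diagonal
    return -rearrange(slopes, "b h -> b h 1 1") * rel.float()

mask = torch.arange(16)[None, :] < torch.tensor([[10], [16]])  # (2, 16) made-up lengths
bias = alibi_bias_padded_sketch(torch.rand(2, 9) * 0.3, seqlen_q=4, key_padding_mask=mask)
assert bias.shape == (2, 9, 4, 16)
```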
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
if rotary_dim > 0: if rotary_dim > 0:
angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
...@@ -1785,8 +1937,6 @@ def test_flash_attn_kvcache( ...@@ -1785,8 +1937,6 @@ def test_flash_attn_kvcache(
# k_cache[:, 64:] = -1 # k_cache[:, 64:] = -1
k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
if new_kv: if new_kv:
update_mask = torch.logical_and( update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
...@@ -1808,6 +1958,7 @@ def test_flash_attn_kvcache( ...@@ -1808,6 +1958,7 @@ def test_flash_attn_kvcache(
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
rotary_interleaved=rotary_interleaved, rotary_interleaved=rotary_interleaved,
alibi_slopes=alibi_slopes,
num_splits=num_splits, num_splits=num_splits,
) )
# out = flash_attn_with_kvcache( # out = flash_attn_with_kvcache(
...@@ -1820,13 +1971,13 @@ def test_flash_attn_kvcache( ...@@ -1820,13 +1971,13 @@ def test_flash_attn_kvcache(
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# probs = torch.softmax(qk, dim=-1) # probs = torch.softmax(qk, dim=-1)
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
out_ref, _ = attention_ref( out_ref, _ = attention_ref(
q_ro, q_ro,
k_cache_rep, k_cache_rep,
v_cache_rep, v_cache_rep,
None, None,
key_padding_mask, key_padding_mask,
attn_bias,
0.0, 0.0,
None, None,
causal=causal, causal=causal,
...@@ -1838,6 +1989,7 @@ def test_flash_attn_kvcache( ...@@ -1838,6 +1989,7 @@ def test_flash_attn_kvcache(
v_cache_rep, v_cache_rep,
None, None,
key_padding_mask, key_padding_mask,
attn_bias,
0.0, 0.0,
None, None,
causal=causal, causal=causal,
...@@ -1857,7 +2009,8 @@ def test_flash_attn_kvcache( ...@@ -1857,7 +2009,8 @@ def test_flash_attn_kvcache(
v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx] v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
assert torch.equal(v_cache_select, v_cache_ref) assert torch.equal(v_cache_select, v_cache_ref)
assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5 mult = 3 if not alibi else 5
assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
......