Unverified Commit e4f726fc authored by Sanghun Cho, committed by GitHub

Support alibi, by Sanghun Cho from Kakao Brain



* hard-code alibi in fwd

* use params.h as num_heads

* hard-code alibi in bwd

* add alibi on/off option

* compute alibi_start, ratio outside of kernels

* fix minor merge conflict

* add test_alibi.py

* change apply_alibi() location before masking

* add alibi in splitkv kernel

* fix backward func # of returns

* add out-of-bound check in apply_alibi()

* update test_alibi.py

* update test_alibi.py for kvcache

* simplify alibi parameter interface

* fix performance issue by computing alibi outside of branch

* update test_flash_attn_varlen_func() for left padding

* implement alibi_slopes (b, nh) loading

* optimize apply_alibi() a bit

* update test cases for alibi_slopes loading

* reflect stylistic comments

* disable "seqlenq_ngroups_swapped" when using alibi

---------
Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
parent cd089597
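
Editor's note: for context on the new `alibi_slopes` argument added throughout this diff, the C++ checks below require an fp32 CUDA tensor of shape (batch_size, num_heads) with a contiguous last dimension, threaded from flash_attn_interface.py down to the kernels. The sketch below is illustrative only; the geometric-slope helper and the commented-out call are assumptions, not code from this commit.

import torch

def geometric_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Standard ALiBi slopes 2^(-8*i/num_heads) for i = 1..num_heads (Press et al.);
    # this is only one way to fill the tensor, the kernels accept any fp32 slopes.
    start = 2.0 ** (-8.0 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)], dtype=torch.float32)

batch_size, seqlen, num_heads, head_dim = 2, 512, 8, 64
q = torch.randn(batch_size, seqlen, num_heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Shape (batch_size, num_heads), fp32, same CUDA device, contiguous last dim:
# exactly what the TORCH_CHECK / CHECK_SHAPE calls added in this diff require.
alibi_slopes = geometric_alibi_slopes(num_heads).cuda().expand(batch_size, -1).contiguous()

# Hypothetical call through the updated private wrapper shown later in this diff:
# out, *_ = _flash_attn_forward(q, k, v, 0.0, head_dim ** -0.5, True, (-1, -1),
#                               alibi_slopes, False)
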
@@ -258,6 +258,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
bool is_causal,
const int window_size_left,
int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
const bool return_softmax,
c10::optional<at::Generator> gen_) {
@@ -301,7 +302,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
// TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -409,6 +411,19 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
params.philox_args = gen->philox_cuda_state(counter_offset);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
} else {
params.has_alibi = false;
}
if (seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
@@ -449,6 +464,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const bool is_causal,
const int window_size_left,
int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
const bool return_softmax,
c10::optional<at::Generator> gen_) {
@@ -591,6 +607,19 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
params.philox_args = gen->philox_cuda_state(counter_offset);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
} else {
params.has_alibi = false;
}
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
@@ -640,6 +669,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const bool is_causal,
const int window_size_left,
int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {
@@ -813,6 +843,19 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
params.rng_state[1] = std::get<1>(seeds);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
} else {
params.has_alibi = false;
}
if (seqlen_q > 0) {
launch(params, stream, /*configure=*/false);
} else {
@@ -856,6 +899,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const bool is_causal,
const int window_size_left,
int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {
@@ -1045,6 +1089,19 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
params.rng_state[1] = std::get<1>(seeds);
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
} else {
params.has_alibi = false;
}
launch(params, stream, /*configure=*/false);
// For MQA/GQA we need to sum dK and dV across the groups
@@ -1077,7 +1134,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
const int window_size_left,
int window_size_right,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits,
c10::optional<at::Tensor> &alibi_slopes_ // batch_size x num_heads
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -1121,7 +1179,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
// TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
@@ -1283,6 +1342,19 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
params.oaccum_ptr = out_accum.data_ptr();
}
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
params.has_alibi = true;
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
} else {
params.has_alibi = false;
}
auto stream = at::cuda::getCurrentCUDAStream().stream();
// Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx
run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value());
...
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout>
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int max_seqlen_k,
const int row_idx_offset_,
const int max_seqlen_q,
const int warp_row_stride,
const int head_idx,
const float softmax_scale,
const float alibi_slope) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int row_idx_offset = row_idx_offset_;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
const float alibi_slope_unscaled = alibi_slope / softmax_scale;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
const float alibi = alibi_slope_unscaled * col_idx;
if (col_idx < max_seqlen_k && row_idx < max_seqlen_q) {
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi;
}
}
}
}
}
}
} // namespace flash
\ No newline at end of file
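
Editor's note on the math in apply_alibi above: the kernel adds alibi_slope_unscaled * col_idx to the raw, not-yet-scaled scores, where alibi_slope_unscaled = alibi_slope / softmax_scale, so after the later multiplication by softmax_scale the effective bias is alibi_slope * col_idx. This differs from the textbook -slope * (row_idx - col_idx) only by a per-row constant, which the row-wise softmax cancels. Below is a reference-only PyTorch sketch of the same computation; it is an illustration, not this kernel's code path.

import torch

def apply_alibi_reference(scores: torch.Tensor, alibi_slopes: torch.Tensor,
                          softmax_scale: float) -> torch.Tensor:
    # scores: (batch, num_heads, seqlen_q, seqlen_k), raw Q @ K^T before scaling
    # alibi_slopes: (batch, num_heads) fp32, mirroring the (b, nh) loading in the kernels
    col_idx = torch.arange(scores.shape[-1], device=scores.device, dtype=scores.dtype)
    bias = alibi_slopes.to(scores.dtype)[:, :, None, None] * col_idx
    # Divide by softmax_scale because the kernel adds the bias to unscaled scores and
    # scales everything afterwards, making the effective bias slope * col_idx.
    return scores + bias / softmax_scale
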
@@ -130,6 +130,13 @@ struct Flash_fwd_params : public Qkv_params {
bool is_rotary_interleaved;
int num_splits; // For split-KV version
// float alibi_start;
// float alibi_ratio;
bool has_alibi;
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -15,6 +15,8 @@
#include "utils.h"
#include "softmax.h"
#include "alibi.h"
namespace flash {
using namespace cute;
@@ -422,7 +424,7 @@ inline __device__ void convert_dKV(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
using Element = typename Kernel_traits::Element;
@@ -790,6 +792,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
clear(acc_dv);
clear(acc_dk);
float alibi_slope = 0.0f;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
}
for (; m_block >= m_block_min; --m_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
clear(acc_s);
@@ -813,6 +828,20 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// if (cute::thread(32, 0)) { print(scores); }
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
AtomLayoutMS * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
// actual_seqlen_k, because acc_s would be some finite value for those indices.
// In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
@@ -849,6 +878,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
}
}
// if (cute::thread(32, 0)) { print(scores); }
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
@@ -1114,7 +1144,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
@@ -1373,6 +1403,19 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
clear(acc_dq);
float alibi_slope = 0.0f;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
}
for (; n_block >= 0; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M_SdP, MMA_N)
clear(acc_s);
@@ -1384,6 +1427,20 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k,
m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
AtomLayoutMS * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
// We don't need to mask out the elements beyond actual_seqlen_k, because acc_s would
// be some finite value for those indices. In the end when we multiply with K to get dQ,
// the corresponding values of K would be 0, so the result would still be correct.
@@ -1394,6 +1451,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
binfo.actual_seqlen_q,
AtomLayoutMS * 16);
}
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
if (Is_dropout) {
@@ -1536,7 +1594,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv(const Params &params) {
// The block index for the batch.
@@ -1550,20 +1608,20 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
if (n_block_max == 1) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
} else {
// Iterating backward from n_block_max - 1 to 0 might save 1 register
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
}
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
const int n_block = blockIdx.x;
@@ -1572,12 +1630,12 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
// The block index for the head.
const int bidh = blockIdx.z;
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_seqq_parallel(const Params &params) {
const int m_block = blockIdx.x;
@@ -1586,7 +1644,7 @@ inline __device__ void compute_dq_dk_dv_seqq_parallel(const Params &params) {
// The block index for the head.
const int bidh = blockIdx.z;
compute_dq_dk_dv_1rowblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_N, Is_even_K>(params, bidb, bidh, m_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -18,20 +18,20 @@ __global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) {
flash::clear_dKVaccum<Kernel_traits>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_N, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) {
flash::compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_N, Is_even_K>(params);
}
template<typename Kernel_traits>
@@ -64,17 +64,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
@@ -107,15 +109,17 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
...
@@ -15,6 +15,8 @@
#include "utils.h"
#include "softmax.h"
#include "alibi.h"
namespace flash {
using namespace cute;
@@ -71,7 +73,7 @@ inline __device__ void write_softmax_to_gmem(
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
@@ -326,6 +328,22 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
float alibi_slope = 0.0f;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
// if (m_block == 0 && tidx == 0) {
// printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
// }
}
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local)
@@ -362,6 +380,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
if (!Is_causal && !Is_local) {
if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
@@ -466,6 +498,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
flash::apply_mask_local(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
@@ -474,6 +520,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
params.window_size_left, params.window_size_right
);
}
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(scores);
@@ -581,7 +628,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
using Element = typename Kernel_traits::Element;
@@ -909,6 +956,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
float alibi_slope = 0.0f;
if (Has_alibi) {
Tensor gAS = make_tensor(
make_gmem_ptr(
reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
+ bidb * params.alibi_slopes_batch_stride + bidh
),
Shape<_1>{});
Tensor rAS = make_fragment_like(gAS);
cute::copy(gAS, rAS);
alibi_slope = rAS(0);
// if (m_block == 0 && tidx == 0) {
// printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
// }
}
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local)
@@ -941,6 +1004,20 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
// if (cute::thread0()) { print(scores); }
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
@@ -1020,6 +1097,20 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) {
flash::apply_alibi(
scores,
n_block * kBlockN,
binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q,
kNWarps * 16,
bidh, params.scale_softmax,
alibi_slope
);
}
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
flash::apply_mask_local(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
@@ -1131,7 +1222,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -1147,12 +1238,12 @@ inline __device__ void compute_attn(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
inline __device__ void compute_attn_splitkv(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -1161,7 +1252,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
const int n_split_idx = Split ? blockIdx.y : 0;
const int num_n_splits = Split ? gridDim.y : 1;
flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -10,15 +10,15 @@
#include "flash.h"
#include "flash_fwd_kernel.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
}
template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
@@ -45,24 +45,26 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
// Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
@@ -84,18 +86,20 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
...
@@ -43,7 +43,7 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
return (128, 64) if is_sm80 else (64, 64)
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
@@ -56,6 +56,7 @@ def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size,
causal,
window_size[0],
window_size[1],
alibi_slopes,
return_softmax,
None,
)
@@ -74,6 +75,7 @@ def _flash_attn_varlen_forward(
softmax_scale,
causal,
window_size,
alibi_slopes,
return_softmax, return_softmax,
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
...@@ -94,6 +96,7 @@ def _flash_attn_varlen_forward( ...@@ -94,6 +96,7 @@ def _flash_attn_varlen_forward(
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
return_softmax, return_softmax,
None, None,
) )
...@@ -116,6 +119,7 @@ def _flash_attn_backward( ...@@ -116,6 +119,7 @@ def _flash_attn_backward(
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
rng_state=None, rng_state=None,
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
...@@ -136,6 +140,7 @@ def _flash_attn_backward( ...@@ -136,6 +140,7 @@ def _flash_attn_backward(
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
None, None,
rng_state, rng_state,
) )
...@@ -160,6 +165,7 @@ def _flash_attn_varlen_backward( ...@@ -160,6 +165,7 @@ def _flash_attn_varlen_backward(
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
rng_state=None, rng_state=None,
): ):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
...@@ -185,6 +191,7 @@ def _flash_attn_varlen_backward( ...@@ -185,6 +191,7 @@ def _flash_attn_varlen_backward(
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
None, None,
rng_state, rng_state,
) )
...@@ -195,7 +202,7 @@ def _flash_attn_varlen_backward( ...@@ -195,7 +202,7 @@ def _flash_attn_varlen_backward(
class FlashAttnQKVPackedFunc(torch.autograd.Function): class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, return_softmax): def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5) softmax_scale = qkv.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...@@ -206,6 +213,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): ...@@ -206,6 +213,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
...@@ -213,6 +221,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): ...@@ -213,6 +221,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -234,10 +243,11 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function): ...@@ -234,10 +243,11 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None return dqkv, None, None, None, None, None, None
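The extra trailing None in each backward return follows the general torch.autograd.Function rule: backward must return one value per forward argument after ctx, and non-differentiable arguments such as alibi_slopes get None in their slot. With forward now taking seven arguments (qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax), backward returns seven values. A toy sketch of the same pattern, with hypothetical names:

import torch

class _ScaledIdentity(torch.autograd.Function):
    # Two forward inputs after ctx, so backward must return exactly two values;
    # the non-differentiable `scale` argument gets None, like alibi_slopes above.
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx.scale, None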
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
...@@ -251,6 +261,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): ...@@ -251,6 +261,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_softmax, return_softmax,
): ):
if softmax_scale is None: if softmax_scale is None:
...@@ -267,6 +278,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): ...@@ -267,6 +278,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
...@@ -275,6 +287,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): ...@@ -275,6 +287,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -300,15 +313,16 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): ...@@ -300,15 +313,16 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None, None return dqkv, None, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function): class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, return_softmax): def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...@@ -319,6 +333,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function): ...@@ -319,6 +333,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
...@@ -326,6 +341,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function): ...@@ -326,6 +341,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -348,11 +364,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function): ...@@ -348,11 +364,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]] dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None return dq, dkv, None, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
...@@ -369,6 +386,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -369,6 +386,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_softmax, return_softmax,
): ):
if softmax_scale is None: if softmax_scale is None:
...@@ -385,6 +403,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -385,6 +403,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward( ctx.save_for_backward(
...@@ -396,6 +415,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -396,6 +415,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -422,16 +442,17 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): ...@@ -422,16 +442,17 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]] dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None, None, None return dq, dkv, None, None, None, None, None, None, None, None, None, None
class FlashAttnFunc(torch.autograd.Function): class FlashAttnFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, return_softmax): def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax):
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
...@@ -442,6 +463,7 @@ class FlashAttnFunc(torch.autograd.Function): ...@@ -442,6 +463,7 @@ class FlashAttnFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
...@@ -449,6 +471,7 @@ class FlashAttnFunc(torch.autograd.Function): ...@@ -449,6 +471,7 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -469,12 +492,13 @@ class FlashAttnFunc(torch.autograd.Function): ...@@ -469,12 +492,13 @@ class FlashAttnFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]] dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None return dq, dk, dv, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function): class FlashAttnVarlenFunc(torch.autograd.Function):
...@@ -492,6 +516,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -492,6 +516,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_softmax, return_softmax,
): ):
if softmax_scale is None: if softmax_scale is None:
...@@ -508,6 +533,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -508,6 +533,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
softmax_scale, softmax_scale,
causal=causal, causal=causal,
window_size=window_size, window_size=window_size,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, return_softmax=return_softmax and dropout_p > 0,
) )
ctx.save_for_backward( ctx.save_for_backward(
...@@ -519,6 +545,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -519,6 +545,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale ctx.softmax_scale = softmax_scale
ctx.causal = causal ctx.causal = causal
ctx.window_size = window_size ctx.window_size = window_size
ctx.alibi_slopes = alibi_slopes
return out if not return_softmax else (out, softmax_lse, S_dmask) return out if not return_softmax else (out, softmax_lse, S_dmask)
@staticmethod @staticmethod
...@@ -543,12 +570,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ...@@ -543,12 +570,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.softmax_scale, ctx.softmax_scale,
ctx.causal, ctx.causal,
ctx.window_size, ctx.window_size,
ctx.alibi_slopes,
rng_state=rng_state, rng_state=rng_state,
) )
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]] dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func( def flash_attn_qkvpacked_func(
...@@ -557,6 +585,7 @@ def flash_attn_qkvpacked_func( ...@@ -557,6 +585,7 @@ def flash_attn_qkvpacked_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -589,7 +618,7 @@ def flash_attn_qkvpacked_func( ...@@ -589,7 +618,7 @@ def flash_attn_qkvpacked_func(
pattern (negative means that location was dropped, nonnegative means it was kept). pattern (negative means that location was dropped, nonnegative means it was kept).
""" """
return FlashAttnQKVPackedFunc.apply( return FlashAttnQKVPackedFunc.apply(
qkv, dropout_p, softmax_scale, causal, window_size, return_attn_probs qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
) )
...@@ -600,6 +629,7 @@ def flash_attn_kvpacked_func( ...@@ -600,6 +629,7 @@ def flash_attn_kvpacked_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -648,7 +678,7 @@ def flash_attn_kvpacked_func( ...@@ -648,7 +678,7 @@ def flash_attn_kvpacked_func(
pattern (negative means that location was dropped, nonnegative means it was kept). pattern (negative means that location was dropped, nonnegative means it was kept).
""" """
return FlashAttnKVPackedFunc.apply( return FlashAttnKVPackedFunc.apply(
q, kv, dropout_p, softmax_scale, causal, window_size, return_attn_probs q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
) )
...@@ -660,6 +690,7 @@ def flash_attn_func( ...@@ -660,6 +690,7 @@ def flash_attn_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -706,7 +737,7 @@ def flash_attn_func( ...@@ -706,7 +737,7 @@ def flash_attn_func(
pattern (negative means that location was dropped, nonnegative means it was kept). pattern (negative means that location was dropped, nonnegative means it was kept).
""" """
return FlashAttnFunc.apply( return FlashAttnFunc.apply(
q, k, v, dropout_p, softmax_scale, causal, window_size, return_attn_probs q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_attn_probs
) )
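A hedged end-to-end example of the new keyword argument on the dense path, assuming the package is importable as flash_attn and built with this extension; the shapes and the inline slope recipe are illustrative only.

import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q, k, v = (torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
           for _ in range(3))
# One fp32 slope per (batch, head); here the standard geometric ALiBi recipe.
slopes = torch.tensor([2 ** (-8.0 * (i + 1) / nheads) for i in range(nheads)],
                      dtype=torch.float32, device="cuda").repeat(batch, 1)
out = flash_attn_func(q, k, v, causal=True, alibi_slopes=slopes)  # (batch, seqlen, nheads, headdim)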
...@@ -718,6 +749,7 @@ def flash_attn_varlen_qkvpacked_func( ...@@ -718,6 +749,7 @@ def flash_attn_varlen_qkvpacked_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -760,6 +792,7 @@ def flash_attn_varlen_qkvpacked_func( ...@@ -760,6 +792,7 @@ def flash_attn_varlen_qkvpacked_func(
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_attn_probs, return_attn_probs,
) )
...@@ -775,6 +808,7 @@ def flash_attn_varlen_kvpacked_func( ...@@ -775,6 +808,7 @@ def flash_attn_varlen_kvpacked_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -839,6 +873,7 @@ def flash_attn_varlen_kvpacked_func( ...@@ -839,6 +873,7 @@ def flash_attn_varlen_kvpacked_func(
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_attn_probs, return_attn_probs,
) )
...@@ -855,6 +890,7 @@ def flash_attn_varlen_func( ...@@ -855,6 +890,7 @@ def flash_attn_varlen_func(
softmax_scale=None, softmax_scale=None,
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
alibi_slopes=None,
return_attn_probs=False, return_attn_probs=False,
): ):
"""dropout_p should be set to 0.0 during evaluation """dropout_p should be set to 0.0 during evaluation
...@@ -918,6 +954,7 @@ def flash_attn_varlen_func( ...@@ -918,6 +954,7 @@ def flash_attn_varlen_func(
softmax_scale, softmax_scale,
causal, causal,
window_size, window_size,
alibi_slopes,
return_attn_probs, return_attn_probs,
) )
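On the varlen path, alibi_slopes is passed the same way alongside the packed sequences; a hedged sketch with two sequences of different lengths, illustrative shapes only.

import torch
from flash_attn import flash_attn_varlen_func

nheads, headdim = 8, 64
seqlens = [300, 512]                                  # two packed variable-length sequences
cu_seqlens = torch.tensor([0, 300, 812], dtype=torch.int32, device="cuda")
total = sum(seqlens)
q = torch.randn(total, nheads, headdim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
# One fp32 slope per (sequence, head).
slopes = torch.tensor([2 ** (-8.0 * (i + 1) / nheads) for i in range(nheads)],
                      dtype=torch.float32, device="cuda").repeat(len(seqlens), 1)
out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
                             max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
                             causal=True, alibi_slopes=slopes)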
...@@ -937,6 +974,7 @@ def flash_attn_with_kvcache( ...@@ -937,6 +974,7 @@ def flash_attn_with_kvcache(
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
rotary_interleaved=True, rotary_interleaved=True,
num_splits=0, num_splits=0,
alibi_slopes=None,
): ):
""" """
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
...@@ -1041,5 +1079,6 @@ def flash_attn_with_kvcache( ...@@ -1041,5 +1079,6 @@ def flash_attn_with_kvcache(
window_size[1], window_size[1],
rotary_interleaved, rotary_interleaved,
num_splits, num_splits,
alibi_slopes
) )
return out return out
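For incremental decoding, the slopes tensor rides along with the KV-cache call in the same way. A hedged sketch, assuming one new token per sequence and only keyword arguments that appear in this file; cache sizes and shapes are illustrative.

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 8, 64, 4096
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device="cuda")  # tokens already cached
slopes = torch.tensor([2 ** (-8.0 * (i + 1) / nheads) for i in range(nheads)],
                      dtype=torch.float32, device="cuda").repeat(batch, 1)

q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")  # one new query token
k_new, v_new = torch.randn_like(q), torch.randn_like(q)
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens, causal=True,
                              alibi_slopes=slopes)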