Commit 5ab9b366 authored by Tri Dao

Clean up alibi, implement non-causal alibi

parent bc28eacc
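For orientation before the diff: the kernels below add an ALiBi bias to the pre-softmax attention scores, and this commit extends that bias from a causal-only form to a non-causal form. A minimal reference sketch of the bias (Python; names and shapes are illustrative, not the library's code):

import torch

def alibi_bias_ref(seqlen_q, seqlen_k, slopes, causal):
    # slopes: (nheads,) fp32 -> returns (nheads, seqlen_q, seqlen_k)
    i = torch.arange(seqlen_q).view(1, seqlen_q, 1)
    j = torch.arange(seqlen_k).view(1, 1, seqlen_k)
    s = slopes.view(-1, 1, 1)
    if causal:
        # Under a causal mask (j <= i + seqlen_k - seqlen_q), the full bias below differs from
        # slope * j only by a per-row constant, which softmax cancels, so the causal kernel can
        # add the same bias vector to every row.
        return s * j
    # Non-causal: the bias depends on both the query (row) and key (column) index.
    return -s * (i + seqlen_k - seqlen_q - j).abs()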
...@@ -253,12 +253,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size ...@@ -253,12 +253,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, const float p_dropout,
const float softmax_scale, const float softmax_scale,
bool is_causal, bool is_causal,
const int window_size_left, const int window_size_left,
int window_size_right, int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
const bool return_softmax, const bool return_softmax,
c10::optional<at::Generator> gen_) { c10::optional<at::Generator> gen_) {
...@@ -297,13 +297,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size ...@@ -297,13 +297,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
-    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
     if (is_causal) { window_size_right = 0; }
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
if (seqlenq_ngroups_swapped) { if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k; const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
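The swap itself is unchanged; a hedged sketch of what `seqlenq_ngroups_swapped` does for decode-style calls (PyTorch shapes only; variable names here are illustrative):

# q: (batch, 1, nheads, d) with nheads = nheads_kv * ngroups (grouped-query attention)
ngroups = nheads // nheads_kv
q = q.reshape(batch, nheads_kv, ngroups, d).transpose(1, 2)  # -> (batch, ngroups, nheads_kv, d)
# The kernel then runs as if seqlen_q == ngroups with nheads_kv query heads, and the output is
# transposed back afterwards; the path is skipped when ALiBi slopes are given (see the check above).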
...@@ -416,12 +416,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size ...@@ -416,12 +416,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes); CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.has_alibi = true;
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }
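With the relaxed check, slopes may be given per head only; a batch stride of 0 then makes every batch element read the same slopes. A small sketch of the indexing the kernels perform (illustrative, mirroring alibi_slopes_ptr[bidb * alibi_slopes_batch_stride + bidh]):

# slopes_buf: flat fp32 buffer; batch_stride = nheads for (batch_size, nheads) slopes, 0 for (nheads,)
def slope_for(slopes_buf, batch_stride, bidb, bidh):
    return slopes_buf[bidb * batch_stride + bidh]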
if (seqlen_k > 0) { if (seqlen_k > 0) {
...@@ -456,6 +455,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q ...@@ -456,6 +455,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1 const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q, const int max_seqlen_q,
const int max_seqlen_k, const int max_seqlen_k,
const float p_dropout, const float p_dropout,
...@@ -464,7 +464,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q ...@@ -464,7 +464,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const bool is_causal, const bool is_causal,
const int window_size_left, const int window_size_left,
int window_size_right, int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
const bool return_softmax, const bool return_softmax,
c10::optional<at::Generator> gen_) { c10::optional<at::Generator> gen_) {
...@@ -612,12 +611,11 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q ...@@ -612,12 +611,11 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes); CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.has_alibi = true;
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
...@@ -664,12 +662,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si ...@@ -664,12 +662,12 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
const float softmax_scale, const float softmax_scale,
const bool is_causal, const bool is_causal,
const int window_size_left, const int window_size_left,
int window_size_right, int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // batch_size x num_heads
c10::optional<at::Generator> gen_, c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) { c10::optional<at::Tensor> &rng_state) {
...@@ -848,12 +846,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si ...@@ -848,12 +846,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes); CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.has_alibi = true;
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }
if (seqlen_q > 0) { if (seqlen_q > 0) {
...@@ -891,6 +888,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size ...@@ -891,6 +888,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1 const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q, const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop const float p_dropout, // probability to drop
...@@ -899,7 +897,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size ...@@ -899,7 +897,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const bool is_causal, const bool is_causal,
const int window_size_left, const int window_size_left,
int window_size_right, int window_size_right,
c10::optional<at::Tensor> &alibi_slopes_, // b x num_heads
c10::optional<at::Generator> gen_, c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) { c10::optional<at::Tensor> &rng_state) {
...@@ -1094,12 +1091,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size ...@@ -1094,12 +1091,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes); CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.has_alibi = true;
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }
launch(params, stream, /*configure=*/false); launch(params, stream, /*configure=*/false);
...@@ -1128,14 +1124,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ...@@ -1128,14 +1124,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2) c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2) c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale, const float softmax_scale,
bool is_causal, bool is_causal,
const int window_size_left, const int window_size_left,
int window_size_right, int window_size_right,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-                int num_splits,
-                c10::optional<at::Tensor> &alibi_slopes_ // batch_size x num_heads
+                int num_splits
) { ) {
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
...@@ -1174,13 +1170,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ...@@ -1174,13 +1170,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
-    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
+    // causal=true is the same as causal=false in this case
+    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
     if (is_causal) { window_size_right = 0; }
     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
-    // TODO: how to make "seqlenq_ngroups_swapped" and ALiBi work together?
-    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !(alibi_slopes_.has_value());
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
if (seqlenq_ngroups_swapped) { if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k; const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
...@@ -1347,12 +1343,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ...@@ -1347,12 +1343,11 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes); CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
-        CHECK_SHAPE(alibi_slopes, batch_size, num_heads);
+        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
-        params.has_alibi = true;
         params.alibi_slopes_ptr = alibi_slopes.data_ptr();
-        params.alibi_slopes_batch_stride = alibi_slopes.stride(0);
+        params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     } else {
-        params.has_alibi = false;
+        params.alibi_slopes_ptr = nullptr;
     }
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
......
@@ -13,37 +13,45 @@ using namespace cute;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <typename Engine, typename Layout>
+template <bool Is_causal, typename Engine, typename Layout>
 inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
                                    const int col_idx_offset_,
                                    const int max_seqlen_k,
-                                   const int row_idx_offset_,
+                                   const int row_idx_offset,
                                    const int max_seqlen_q,
                                    const int warp_row_stride,
-                                   const int head_idx,
-                                   const float softmax_scale,
                                    const float alibi_slope) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
-    const int row_idx_offset = row_idx_offset_;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-    const float alibi_slope_unscaled = alibi_slope / softmax_scale;
-    #pragma unroll
-    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
-        #pragma unroll
-        for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const int row_idx = row_idx_base + i * 8;
-            #pragma unroll
-            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const int col_idx_base = col_idx_offset + nj * 8;
-                #pragma unroll
-                for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const int col_idx = col_idx_base + j;
-                    const float alibi = alibi_slope_unscaled * col_idx;
-                    if (col_idx < max_seqlen_k && row_idx < max_seqlen_q) {
-                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi;
-                    }
-                }
-            }
-        }
-    }
-}
+    if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
+        #pragma unroll
+        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+            const int col_idx_base = col_idx_offset + nj * 8;
+            #pragma unroll
+            for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                const int col_idx = col_idx_base + j;
+                #pragma unroll
+                for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                }
+            }
+        }
+    } else {  // Bias depends on both row_idx and col_idx
+        #pragma unroll
+        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+            #pragma unroll
+            for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                const int row_idx = row_idx_base + i * 8;
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+                    }
+                }
+            }
+        }
+    }
+}

 }  // namespace flash
\ No newline at end of file
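A quick numerical check that the simpler bias in the causal branch is equivalent after softmax (a hedged PyTorch sketch mirroring the two branches above; not a library test):

import torch

torch.manual_seed(0)
sq, sk, h = 5, 8, 3
scores = torch.randn(h, sq, sk)
slopes = torch.rand(h, 1, 1)
i = torch.arange(sq).view(1, sq, 1)
j = torch.arange(sk).view(1, 1, sk)
mask = j > i + sk - sq                        # bottom-right-aligned causal mask
full = -slopes * (i + sk - sq - j).abs()      # non-causal form of the bias
cheap = slopes * j                            # causal form: same bias vector for every row
p_full = torch.softmax((scores + full).masked_fill(mask, float("-inf")), dim=-1)
p_cheap = torch.softmax((scores + cheap).masked_fill(mask, float("-inf")), dim=-1)
# Inside the causal region the two biases differ only by a per-row constant, which softmax cancels.
assert torch.allclose(p_full, p_cheap, atol=1e-6)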
...@@ -131,10 +131,6 @@ struct Flash_fwd_params : public Qkv_params { ...@@ -131,10 +131,6 @@ struct Flash_fwd_params : public Qkv_params {
int num_splits; // For split-KV version int num_splits; // For split-KV version
// float alibi_start;
// float alibi_ratio;
bool has_alibi;
void * __restrict__ alibi_slopes_ptr; void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride; index_t alibi_slopes_batch_stride;
}; };
......
...@@ -753,8 +753,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -753,8 +753,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
#pragma unroll #pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi)); const int row = get<0>(taccScS_row(mi));
-        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
+        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
} }
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.
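The comment can be sanity-checked in isolation; probs are computed as exp(score - lse), so for an out-of-bounds row (a hedged scalar illustration, not kernel code):

import math
# probs = exp(score - lse); an OOB row's scores are only guaranteed to be zero before ALiBi shifts them
print(math.exp(0.0 - 0.0))          # 1.0 -> lse = 0 gives OOB rows probability mass once scores are perturbed
print(math.exp(-7.0 - math.inf))    # 0.0 -> lse = inf forces OOB rows to contribute exactly nothing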
// Tensor tKrK = make_fragment_like(tKsK); // Tensor tKrK = make_fragment_like(tKsK);
// // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK); // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
...@@ -792,18 +796,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -792,18 +796,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
clear(acc_dv); clear(acc_dv);
clear(acc_dk); clear(acc_dk);
-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(
-            make_gmem_ptr(
-                reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
-                + bidb * params.alibi_slopes_batch_stride + bidh
-            ),
-            Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-    }
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
for (; m_block >= m_block_min; --m_block) { for (; m_block >= m_block_min; --m_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N) Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
...@@ -830,18 +823,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in ...@@ -830,18 +823,17 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// if (cute::thread(32, 0)) { print(scores); } // if (cute::thread(32, 0)) { print(scores); }
             if (Has_alibi) {
-                flash::apply_alibi(
+                flash::apply_alibi<Is_causal>(
                     scores,
                     n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                     binfo.actual_seqlen_k,
                     m_block * kBlockM + get<0>(taccScS_row(0)),
                     binfo.actual_seqlen_q,
                     AtomLayoutMS * 16,
-                    bidh,
-                    params.scale_softmax,
                     alibi_slope
                 );
             }
// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
// actual_seqlen_k, because acc_s would be some finite value for those indices. // actual_seqlen_k, because acc_s would be some finite value for those indices.
// In the end when we multiply with K to get dQ, the corresponding values of K would be 0, // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
...@@ -1403,18 +1395,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1403,18 +1395,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
clear(acc_dq); clear(acc_dq);
-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(
-            make_gmem_ptr(
-                reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
-                + bidb * params.alibi_slopes_batch_stride + bidh
-            ),
-            Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-    }
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
for (; n_block >= 0; --n_block) { for (; n_block >= 0; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M_SdP, MMA_N) Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M_SdP, MMA_N)
...@@ -1429,14 +1410,13 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in ...@@ -1429,14 +1410,13 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + get<0>(taccScS_row(0)),
                 binfo.actual_seqlen_q,
                 AtomLayoutMS * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }
......
...@@ -64,12 +64,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream, ...@@ -64,12 +64,11 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
-                    BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                         // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                         // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                         // If Is_local, set Is_causal to false
-                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal && !Is_local, Is_local, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
+                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
-                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) { if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute( C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
...@@ -109,7 +108,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream, ...@@ -109,7 +108,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>; auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Has_alibi, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>; // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
......
...@@ -322,28 +322,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -322,28 +322,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
clear(acc_o); clear(acc_o);
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.
-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(
-            make_gmem_ptr(
-                reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
-                + bidb * params.alibi_slopes_batch_stride + bidh
-            ),
-            Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-        // if (m_block == 0 && tidx == 0) {
-        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
-        // }
-    }
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local) constexpr int n_masking_steps = (!Is_causal && !Is_local)
...@@ -382,14 +368,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -382,14 +368,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// can produce Inf / NaN. // can produce Inf / NaN.
         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }
...@@ -500,14 +485,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi ...@@ -500,14 +485,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }
...@@ -950,28 +934,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -950,28 +934,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
clear(acc_o); clear(acc_o);
+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.
-    float alibi_slope = 0.0f;
-    if (Has_alibi) {
-        Tensor gAS = make_tensor(
-            make_gmem_ptr(
-                reinterpret_cast<ElementAccum *>(params.alibi_slopes_ptr)
-                + bidb * params.alibi_slopes_batch_stride + bidh
-            ),
-            Shape<_1>{});
-        Tensor rAS = make_fragment_like(gAS);
-        cute::copy(gAS, rAS);
-        alibi_slope = rAS(0);
-        // if (m_block == 0 && tidx == 0) {
-        //     printf("%d,%d,%f\n", bidb, bidh, alibi_slope);
-        // }
-    }
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = (!Is_causal && !Is_local) constexpr int n_masking_steps = (!Is_causal && !Is_local)
...@@ -1006,14 +976,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -1006,14 +976,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }
...@@ -1099,14 +1068,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons ...@@ -1099,14 +1068,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         if (Has_alibi) {
-            flash::apply_alibi(
+            flash::apply_alibi<Is_causal>(
                 scores,
                 n_block * kBlockN,
                 binfo.actual_seqlen_k,
                 m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                 binfo.actual_seqlen_q,
                 kNWarps * 16,
-                bidh,
-                params.scale_softmax,
                 alibi_slope
             );
         }
......
...@@ -45,7 +45,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -45,7 +45,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                    BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// Will only return softmax if dropout, to reduce compilation time. // Will only return softmax if dropout, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates // If return_softmax, set IsEvenMNConst to false to reduce number of templates
...@@ -86,7 +86,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) { ...@@ -86,7 +86,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                        BOOL_SWITCH(params.has_alibi, Has_alibi, [&] {
+                        BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If Is_local, set Is_causal to false // If Is_local, set Is_causal to false
......
...@@ -141,14 +141,12 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_ ...@@ -141,14 +141,12 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
template <bool HasWSLeft=true, typename Engine, typename Layout> template <bool HasWSLeft=true, typename Engine, typename Layout>
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_, inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                          const int max_seqlen_k, const int row_idx_offset_,
+                                          const int max_seqlen_k, const int row_idx_offset,
const int max_seqlen_q, const int warp_row_stride, const int max_seqlen_q, const int warp_row_stride,
const int window_size_left, const int window_size_right) { const int window_size_left, const int window_size_right) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor"); static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32; const int lane_id = threadIdx.x % 32;
-    // const int row_idx_offset = row_idx_offset_ + lane_id / 4;
-    const int row_idx_offset = row_idx_offset_;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll #pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
...@@ -180,17 +178,17 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in ...@@ -180,17 +178,17 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in
template <typename Engine, typename Layout> template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_, inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
-                                           const int max_seqlen_k, const int row_idx_offset_,
+                                           const int max_seqlen_k, const int row_idx_offset,
                                            const int max_seqlen_q, const int warp_row_stride) {
     // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
-    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
max_seqlen_q, warp_row_stride, -1, 0); max_seqlen_q, warp_row_stride, -1, 0);
} }
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1> template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx( inline __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol, Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{ {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout0::rank == 2, "Only support 2D Tensor");
...@@ -199,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx( ...@@ -199,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
#pragma unroll #pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) { for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
#pragma unroll #pragma unroll
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
......
...@@ -53,12 +53,12 @@ def _flash_attn_forward( ...@@ -53,12 +53,12 @@ def _flash_attn_forward(
k, k,
v, v,
None, None,
alibi_slopes,
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
return_softmax, return_softmax,
None, None,
) )
...@@ -90,6 +90,7 @@ def _flash_attn_varlen_forward( ...@@ -90,6 +90,7 @@ def _flash_attn_varlen_forward(
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
None, None,
alibi_slopes,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
dropout_p, dropout_p,
...@@ -98,7 +99,6 @@ def _flash_attn_varlen_forward( ...@@ -98,7 +99,6 @@ def _flash_attn_varlen_forward(
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
return_softmax, return_softmax,
None, None,
) )
...@@ -137,12 +137,12 @@ def _flash_attn_backward( ...@@ -137,12 +137,12 @@ def _flash_attn_backward(
dq, dq,
dk, dk,
dv, dv,
alibi_slopes,
dropout_p, dropout_p,
softmax_scale, softmax_scale,
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
None, None,
rng_state, rng_state,
) )
...@@ -185,6 +185,7 @@ def _flash_attn_varlen_backward( ...@@ -185,6 +185,7 @@ def _flash_attn_varlen_backward(
dv, dv,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
alibi_slopes,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
dropout_p, dropout_p,
...@@ -193,7 +194,6 @@ def _flash_attn_varlen_backward( ...@@ -193,7 +194,6 @@ def _flash_attn_varlen_backward(
causal, causal,
window_size[0], window_size[0],
window_size[1], window_size[1],
alibi_slopes,
None, None,
rng_state, rng_state,
) )
...@@ -613,6 +613,8 @@ def flash_attn_qkvpacked_func( ...@@ -613,6 +613,8 @@ def flash_attn_qkvpacked_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -673,6 +675,9 @@ def flash_attn_kvpacked_func( ...@@ -673,6 +675,9 @@ def flash_attn_kvpacked_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -732,6 +737,9 @@ def flash_attn_func( ...@@ -732,6 +737,9 @@ def flash_attn_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
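The bias formula in these docstrings can be turned into a small reference to test against; a hedged sketch (assumes a (batch, seqlen, nheads, headdim) layout with equal query/key heads, fp32 math, and the default 1/sqrt(headdim) scale; not the library's own test code):

import math
import torch

def attn_ref_alibi(q, k, v, alibi_slopes, causal=False):
    b, sq, h, d = q.shape
    sk = k.shape[1]
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) / math.sqrt(d)
    i = torch.arange(sq, device=q.device).view(1, 1, sq, 1)
    j = torch.arange(sk, device=q.device).view(1, 1, 1, sk)
    slopes = alibi_slopes.view(-1, h, 1, 1)  # handles both (nheads,) and (batch_size, nheads)
    scores = scores - slopes * (i + sk - sq - j).abs()
    if causal:
        scores = scores.masked_fill(j > i + sk - sq, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, v.float()).to(q.dtype)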
...@@ -780,6 +788,8 @@ def flash_attn_varlen_qkvpacked_func( ...@@ -780,6 +788,8 @@ def flash_attn_varlen_qkvpacked_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -858,6 +868,9 @@ def flash_attn_varlen_kvpacked_func( ...@@ -858,6 +868,9 @@ def flash_attn_varlen_kvpacked_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -938,6 +951,9 @@ def flash_attn_varlen_func( ...@@ -938,6 +951,9 @@ def flash_attn_varlen_func(
Default to 1 / sqrt(headdim). Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling). (they might not have the right scaling).
...@@ -981,8 +997,8 @@ def flash_attn_with_kvcache( ...@@ -981,8 +997,8 @@ def flash_attn_with_kvcache(
causal=False, causal=False,
window_size=(-1, -1), # -1 means infinite context window window_size=(-1, -1), # -1 means infinite context window
rotary_interleaved=True, rotary_interleaved=True,
num_splits=0,
alibi_slopes=None, alibi_slopes=None,
num_splits=0,
): ):
""" """
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
...@@ -1050,6 +1066,9 @@ def flash_attn_with_kvcache( ...@@ -1050,6 +1066,9 @@ def flash_attn_with_kvcache(
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style). (i.e. GPT-NeoX style).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
num_splits: int. If > 1, split the key/value into this many chunks along the sequence. num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits. to automatically determine the number of splits.
...@@ -1080,6 +1099,7 @@ def flash_attn_with_kvcache( ...@@ -1080,6 +1099,7 @@ def flash_attn_with_kvcache(
rotary_cos, rotary_cos,
rotary_sin, rotary_sin,
cache_batch_idx, cache_batch_idx,
alibi_slopes,
None, None,
softmax_scale, softmax_scale,
causal, causal,
...@@ -1087,6 +1107,5 @@ def flash_attn_with_kvcache( ...@@ -1087,6 +1107,5 @@ def flash_attn_with_kvcache(
window_size[1], window_size[1],
rotary_interleaved, rotary_interleaved,
num_splits, num_splits,
alibi_slopes,
) )
return out return out
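Downstream, only the new keyword argument is needed; a hedged usage sketch (tensor shapes and dtypes are illustrative):

import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
slopes = torch.rand(8, device="cuda", dtype=torch.float32)        # (nheads,) or (batch_size, nheads)
out = flash_attn_func(q, k, v, causal=True, alibi_slopes=slopes)  # (2, 1024, 8, 64)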