Unverified Commit ee71ed8a authored by PGFLMG and committed by GitHub

[Feat] QWen-1M context support[1/2]: Update block sparse attention backend utils kernel (#5847)


Co-authored-by: sighingnow <sighingnow@gmail.com>
parent d364b9b0
@@ -176,6 +176,7 @@ set(SOURCES
"csrc/attention/cascade.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/cutlass_mla_kernel.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/attention/lightning_attention_decode_kernel.cu"
"csrc/elementwise/activation.cu"
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
...
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// CUDA utility kernels for block-sparse attention (vertical/slash index conversion).
#include <assert.h>
#include <cuda.h>
#include <torch/all.h>
// Save the start index of each block in the given range into block_offset.
// Returns the updated block count.
__device__ int64_t save_blocks(
int* block_offset,
int64_t range_start,
int64_t range_end,
int64_t block_size,
int64_t input_block_count,
int64_t kv_seqlen) {
if (range_start >= kv_seqlen) {
return input_block_count;
}
if (range_end > kv_seqlen) {
range_end = kv_seqlen;
}
int64_t current_block_count = input_block_count;
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[current_block_count++] = idx;
}
return current_block_count;
}
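// Illustrative example (hypothetical values): with range_start = 128,
// range_end = 256, block_size = 64 and kv_seqlen = 200, range_end is first
// clamped to 200, the starts 128 and 192 are written to block_offset, and
// the function returns input_block_count + 2.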
// CUDA kernel: convert sparse vertical/slash indices to block/column offsets.
__global__ void convert_vertical_slash_indexes_kernel(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS,
int64_t N_ROWS,
int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N,
int64_t NNZ_V,
int64_t NNZ_S,
bool causal  // true for the intra (causal) phase, false for the succ (non-causal) phase
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
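// The loop below merges the two sorted index streams: a vertical index that
// falls before the current slash range [range_start, range_end) is recorded
// in column_index (indices already covered by the range are dropped); a new
// slash that overlaps or lands within one block of the current range extends
// it, while a farther one flushes the accumulated range into block_offset
// via save_blocks.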
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) || (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// slashes are exhausted but verticals remain; save the current blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
// Host function: launches the kernel with 64 threads per block.
void convert_vertical_slash_indexes_64x64(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE,
int64_t N_HEADS,
int64_t N_ROWS,
int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N,
int64_t NNZ_V,
int64_t NNZ_S,
bool causal) {
const int N_THREADS = 64;
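// Grid layout: x = head, y = batch, z = groups of 64 query row blocks; each
// thread handles one BLOCK_SIZE_M row block.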
const dim3 dimBlock((int32_t)N_THREADS);
const dim3 dimGrid(
(int32_t)N_HEADS, (int32_t)BATCH_SIZE, ((int32_t)N_ROWS + (int32_t)N_THREADS - 1) / (int32_t)N_THREADS);
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
block_count,
block_offset,
column_count,
column_index,
N_HEADS,
N_ROWS,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
NNZ_V,
NNZ_S,
causal);
}
// Host function: prepares tensor pointers and launches the CUDA kernel.
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size,
int64_t block_size_M,
int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int64_t batch_size = slash_indexes.size(0);
int64_t num_heads = slash_indexes.size(1);
int64_t nnz_slash = slash_indexes.size(2);
int64_t nnz_vertical = vertical_indexes.size(2);
int64_t num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64(
q_seqlens.data_ptr<int>(),
kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(),
slash_indexes.data_ptr<int>(),
block_count.data_ptr<int>(),
block_offset.data_ptr<int>(),
column_count.data_ptr<int>(),
column_index.data_ptr<int>(),
batch_size,
num_heads,
num_rows,
block_size_M,
block_size_N,
nnz_vertical,
nnz_slash,
causal);
}
// --- mergehead kernels --- //
// Kernel: like above, but supports per-head variable NNZ_V/NNZ_S.
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
const int* per_head_vertical_topkv,
const int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS,
int64_t N_ROWS,
int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N,
int64_t NNZ_V,
int64_t NNZ_S,
bool causal  // true for the intra (causal) phase, false for the succ (non-causal) phase
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
// MergeHead: each head has its own top-k NNZ_V/NNZ_S. (The NNZ_V/NNZ_S
// arguments above are the buffer sizes, used only to compute offsets.)
NNZ_S = per_head_slash_topkv[head_idx];
NNZ_V = per_head_vertical_topkv[head_idx];
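// Only the first per-head top-k entries of each head's slice are consumed
// below; the remaining (padded) buffer positions are ignored.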
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) || (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// slashes are exhausted but verticals remain; save the current blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
// Launch the mergehead kernel with 64 threads per block.
void convert_vertical_slash_indexes_64x64_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* per_head_vertical_topkv,
int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE,
int64_t N_HEADS,
int64_t N_ROWS,
int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N,
int64_t NNZ_V,
int64_t NNZ_S,
bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
per_head_vertical_topkv,
per_head_slash_topkv,
block_count,
block_offset,
column_count,
column_index,
N_HEADS,
N_ROWS,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
NNZ_V,
NNZ_S,
causal);
}
// Host wrapper for mergehead kernel.
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count,  // [N_HEADS, ]
int64_t context_size,
int64_t block_size_M,
int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64_mergehead(
q_seqlens.data_ptr<int>(),
kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(),
slash_indexes.data_ptr<int>(),
vertical_indices_count.data_ptr<int>(),
slash_indices_count.data_ptr<int>(),
block_count.data_ptr<int>(),
block_offset.data_ptr<int>(),
column_count.data_ptr<int>(),
column_index.data_ptr<int>(),
batch_size,
num_heads,
num_rows,
block_size_M,
block_size_N,
nnz_vertical,
nnz_slash,
causal);
}
@@ -234,6 +234,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"Generator? gen) -> Tensor[]");
m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
// Sparse Attention utils
m.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! column_count, Tensor! column_index, "
" Tensor q_seqlens, Tensor kv_seqlens, "
" Tensor vertical_indexes, Tensor slash_indexes, "
" int context_size, int block_size_M, int block_size_N, "
" bool causal) -> ()");
m.impl("convert_vertical_slash_indexes", torch::kCUDA, &convert_vertical_slash_indexes);
m.def(
"convert_vertical_slash_indexes_mergehead("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! column_count, Tensor! column_index, "
" Tensor q_seqlens, Tensor kv_seqlens, "
" Tensor vertical_indexes, Tensor slash_indexes, "
" Tensor vertical_indices_count, Tensor slash_indices_count, "
" int context_size, int block_size_M, int block_size_N, "
" bool causal) -> ()");
m.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, &convert_vertical_slash_indexes_mergehead);
/*
 * From XGrammar
 */
...
@@ -353,6 +353,36 @@ std::vector<at::Tensor> mha_varlen_fwd_sparse(
c10::optional<at::Generator> gen_);
}  // namespace flash
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size,
int64_t block_size_M,
int64_t block_size_N,
bool causal);
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count,  // [N_HEADS, ]
int64_t context_size,
int64_t block_size_M,
int64_t block_size_N,
bool causal);
/*
 * From XGrammar
 */
...
@@ -8,6 +8,124 @@ def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
# Sparse attention utils
def convert_vertical_slash_indexes(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
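# Builds the block/column sparse layout used by the block-sparse attention
# backend. Returns, per (batch, head, query row block of size block_size_M):
#   block_count  [BATCH, N_HEADS, num_rows]          number of KV blocks from slash lines
#   block_offset [BATCH, N_HEADS, num_rows, NNZ_S]   start offsets of those KV blocks
#   column_count [BATCH, N_HEADS, num_rows]          number of vertical column indices
#   column_index [BATCH, N_HEADS, num_rows, NNZ_V]   the vertical column indices
# where num_rows = ceil(context_size / block_size_M).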
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.zeros(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
block_offset = torch.zeros(
batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
column_count = torch.zeros(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
column_index = torch.zeros(
batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
torch.ops.sgl_kernel.convert_vertical_slash_indexes.default(
block_count,
block_offset,
column_count,
column_index,
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
context_size,
block_size_M,
block_size_N,
causal,
)
return block_count, block_offset, column_count, column_index
def convert_vertical_slash_indexes_mergehead(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
# [N_HEADS]: each head uses a different number of indices
vertical_indices_count: torch.Tensor,
slash_indices_count: torch.Tensor,
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
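# Same as convert_vertical_slash_indexes, except vertical_indices_count and
# slash_indices_count give each head its own number of valid vertical/slash
# indices; NNZ_V and NNZ_S then only define the padded buffer sizes.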
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.empty(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
block_offset = torch.empty(
batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
column_count = torch.empty(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
column_index = torch.empty(
batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
torch.ops.sgl_kernel.convert_vertical_slash_indexes_mergehead.default(
block_count,
block_offset,
column_count,
column_index,
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
vertical_indices_count,
slash_indices_count,
context_size,
block_size_M,
block_size_N,
causal,
)
return block_count, block_offset, column_count, column_index
def sparse_attn_func(
q,
k,
...
@@ -4,7 +4,12 @@ from typing import List, Optional, Tuple
import pytest
import torch
from einops import rearrange, repeat
from sgl_kernel.sparse_flash_attn import (
convert_vertical_slash_indexes,
convert_vertical_slash_indexes_mergehead,
sparse_attn_func,
sparse_attn_varlen_func,
)
def ref_attn(
@@ -249,6 +254,133 @@ def test_sparse_attention(
), f"{torch.max(torch.abs(lse - ref_lse))}"
# Sparse attention utils
# original (non-mergehead) variant
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes(causal):
# Prepare small, hand-checkable inputs
q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda") # [BATCH]
kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
vertical_indexes = torch.tensor(
[[[1, 3]]], dtype=torch.int32, device="cuda"
) # [BATCH, N_HEADS, NNZ_V]
slash_indexes = torch.tensor(
[[[2]]], dtype=torch.int32, device="cuda"
) # [BATCH, N_HEADS, NNZ_S]
context_size = 4
block_size_M = 2
block_size_N = 2
# Call the CUDA kernel wrapper
block_count, block_offset, column_count, column_index = (
convert_vertical_slash_indexes(
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
context_size,
block_size_M,
block_size_N,
causal=causal,
)
)
# Expected outputs for this input.
# There are 2 query row blocks: row 0 (tokens 0-1) and row 1 (tokens 2-3).
# Output semantics:
# - block_count: number of KV blocks covered by slash lines per row block
# - block_offset: the start offsets of those KV blocks
# - column_count: number of valid vertical (column) indices per row block
# - column_index: the vertical indices themselves
expected_column_index = torch.tensor(
[[[[0, 0], [0, 0]]]], dtype=torch.int32, device="cuda"
)
# Non-causal mode produces different column indices for this input.
if not causal:
expected_column_index = torch.tensor(
[[[[1, 0], [1, 3]]]], dtype=torch.int32, device="cuda"
)
# Assert that outputs match expectations
assert torch.equal(column_index, expected_column_index)
# mergehead
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes_mergehead(causal):
# Prepare small, hand-checkable inputs for mergehead version
q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
vertical_indexes = torch.tensor(
[
[
[1, 3], # head 0
[2, 0], # head 1
]
],
dtype=torch.int32,
device="cuda",
) # [BATCH, N_HEADS, NNZ_V]
slash_indexes = torch.tensor(
[
[
[2, 0], # head 0
[1, 3], # head 1
]
],
dtype=torch.int32,
device="cuda",
) # [BATCH, N_HEADS, NNZ_S]
vertical_indices_count = torch.tensor([2, 1], dtype=torch.int32, device="cuda")
slash_indices_count = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
context_size = 4
block_size_M = 2
block_size_N = 2
# Call the CUDA kernel wrapper
block_count, block_offset, column_count, column_index = (
convert_vertical_slash_indexes_mergehead(
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
vertical_indices_count,
slash_indices_count,
context_size,
block_size_M,
block_size_N,
causal=causal,
)
)
# Expected outputs for this input
# (batch=1, heads=2, num_rows=2, nnz_v=2, nnz_s=2).
expected_column_index = torch.tensor(
[[[[1, 0], [1, 3]], [[-1079459945, -1077788999], [-1080050043, -1104625879]]]],
dtype=torch.int32,
device="cuda",
)
# Non-causal mode produces different column indices for this input.
if not causal:
expected_column_index = torch.tensor(
[[[[1, 0], [1, 3]], [[2, -1077788999], [2, -1104625879]]]],
dtype=torch.int32,
device="cuda",
)
# Assert that outputs match expectations
assert torch.equal(column_index, expected_column_index)
# Skipped because FA2 is used for this test.
# @pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
#                                       [(1024, 1328), (1, 2048)],
#                                       [(1025, 1328), (2, 2048)],
...