Unverified Commit 60f76243 authored by Tao He's avatar Tao He Committed by GitHub
Browse files

Implements dual-chunk-flash-attn backend for dual chunk attention with sparse...

Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support (#11844)
parent f6518b2b
......@@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <assert.h>
#include <cuda.h>
#include <torch/all.h>
__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
int64_t range_end, int64_t block_size,
int64_t input_block_count, int64_t kv_seqlen) {
if (range_start >= kv_seqlen) {
return input_block_count;
}
if (range_end > kv_seqlen) {
range_end = kv_seqlen;
}
int64_t current_block_count = input_block_count;
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[current_block_count++] = idx;
}
return current_block_count;
}
__global__ void convert_vertical_slash_indexes_kernel(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t NNZ_V, int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) ||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
}
/**
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
*
* This function builds the index of each row of blocks from vertical indices
* and slash indices. The vertical indices are treated as points, while the
* slash indices are converted as ranges. The output consists of the merged
* ranges and separate column indices, where the ranges are represented by
* block indices.
*
* The implementation is referenced from the original MInference repo:
* https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu.
*/
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64(
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
causal);
}
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t NNZ_V, int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
// MergeHead: each head has it's unique max topk NNZ_V,NNZ_S. (NNZ_V,NNZ_S
// above is buffer size, use to compute offset)
NNZ_S = per_head_slash_topkv[head_idx];
NNZ_V = per_head_vertical_topkv[head_idx];
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) ||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* per_head_vertical_topkv, int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
NNZ_V, NNZ_S, causal);
}
/**
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
*
* Like the above convert_vertical_slash_indexes, but with
* pre-computed vertical and slash counts.
*/
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, // [N_HEADS, ]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64_mergehead(
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
vertical_indices_count.data_ptr<int>(),
slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
}
......@@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal);
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, int64_t context_size,
int64_t block_size_M, int64_t block_size_N, bool causal);
#endif
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
......
......@@ -77,6 +77,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
ops.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! column_count, Tensor! column_index, "
" Tensor q_seqlens, Tensor q_seqlens, "
" Tensor vertical_indexes, Tensor slash_indexes, "
" int context_size, int block_size_M, int block_size_N, "
" bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes", torch::kCUDA,
&convert_vertical_slash_indexes);
ops.def(
"convert_vertical_slash_indexes_mergehead("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! column_count, Tensor! column_index, "
" Tensor q_seqlens, Tensor q_seqlens, "
" Tensor vertical_indexes, Tensor slash_indexes, "
" Tensor vertical_indices_count, Tensor slash_indices_count, "
" int context_size, int block_size_M, int block_size_N, "
" bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
&convert_vertical_slash_indexes_mergehead);
#endif
// Activation ops
......
# SPDX-License-Identifier: Apache-2.0
import os
from urllib.request import urlopen
from vllm import LLM, SamplingParams
os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN"
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
def load_prompt() -> str:
# Test cases with various lengths can be found at:
#
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
with urlopen(
"https://qianwen-res.oss-cn-beijing.aliyuncs.com"
"/Qwen2.5-1M/test-data/600k.txt",
timeout=5) as response:
prompt = response.read().decode('utf-8')
return prompt
# Processing the prompt.
def process_requests(llm: LLM, prompts: list[str]) -> None:
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.8,
top_k=20,
repetition_penalty=1.05,
detokenize=True,
max_tokens=256,
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt_token_ids = output.prompt_token_ids
generated_text = output.outputs[0].text
print(f"Prompt length: {len(prompt_token_ids)}, "
f"Generated text: {generated_text!r}")
# Create an LLM.
def initialize_engine() -> LLM:
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-1M",
max_model_len=1048576,
tensor_parallel_size=4,
enforce_eager=True,
enable_chunked_prefill=True,
max_num_batched_tokens=131072)
return llm
def main():
llm = initialize_engine()
prompt = load_prompt()
process_requests(llm, [prompt])
if __name__ == '__main__':
main()
......@@ -150,6 +150,101 @@ def merge_attn_states(output: torch.Tensor,
prefix_lse, suffix_output, suffix_lse)
def convert_vertical_slash_indexes(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.zeros(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
block_offset = torch.zeros(batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_count = torch.zeros(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_index = torch.zeros(batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
torch.ops._C.convert_vertical_slash_indexes(
block_count, block_offset, column_count, column_index, q_seqlens,
kv_seqlens, vertical_indexes, slash_indexes, context_size,
block_size_M, block_size_N, causal)
return block_count, block_offset, column_count, column_index
def convert_vertical_slash_indexes_mergehead(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
# [N_HEADS] : different head use different number of indices
vertical_indices_count: torch.Tensor,
slash_indices_count: torch.Tensor,
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.empty(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
block_offset = torch.empty(batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_count = torch.empty(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_index = torch.empty(batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
torch.ops._C.convert_vertical_slash_indexes_mergehead(
block_count, block_offset, column_count, column_index, q_seqlens,
kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count,
slash_indices_count, context_size, block_size_M, block_size_N, causal)
return block_count, block_offset, column_count, column_index
# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,
......
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with Dual chunk flash attention and sparse attention.
"""
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch
import torch.distributed
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionImpl,
FlashAttentionMetadata,
FlashAttentionMetadataBuilder)
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.utils import async_tensor_h2d
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache, sparse_attn_func)
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
logger = init_logger(__name__)
class DualChunkFlashAttentionBackend(FlashAttentionBackend):
accept_output_buffer: bool = False
@staticmethod
def get_name() -> str:
return "DUAL_CHUNK_FLASH_ATTN"
@staticmethod
def get_impl_cls() -> Type["DualChunkFlashAttentionImpl"]:
return DualChunkFlashAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["DualChunkFlashAttentionMetadata"]:
return DualChunkFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["DualChunkFlashAttentionMetadataBuilder"]:
return DualChunkFlashAttentionMetadataBuilder
@dataclass
class DualChunkFlashAttentionMetadata(FlashAttentionMetadata):
# Block size of the paged kv cache.
block_size: int = 16
# Original max position embeddings.
original_max_position_embeddings: int = 0
# Chunk size
chunk_size: int = 8192
# Local size
local_size: int = 1024
# (batch_size,). The orig sequence length per sequence.
orig_seq_lens: Optional[List[int]] = None
# orig_seq_lens stored as a tensor.
orig_seq_lens_tensor: Optional[torch.Tensor] = None
# Length scaling factor
scaling_factor: Optional[torch.Tensor] = None
# (batch_size,). Sequence lengths for intra attention.
seq_lens_intra: Optional[torch.Tensor] = None
# Max sequence length for intra attention.
max_seq_len_intra: Optional[int] = None
# (batch_size, num_blocks). Block table for intra attention.
block_tables_intra: Optional[torch.Tensor] = None
# (batch_size,). Sequence lengths for succ attention.
seq_lens_succ: Optional[torch.Tensor] = None
# Max sequence length for succ attention.
max_seq_len_succ: Optional[int] = None
# (batch_size, num_blocks). Block table for succ attention.
block_tables_succ: Optional[torch.Tensor] = None
# (batch_size,). Sequence lengths for inter attention.
seq_lens_inter: Optional[torch.Tensor] = None
# Max sequence length for inter attention.
max_seq_len_inter: Optional[int] = None
_cached_prefill_metadata: Optional[
"DualChunkFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["DualChunkFlashAttentionMetadata"] = None
@property
def prefill_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
prefill_metadata = super().prefill_metadata
if prefill_metadata is None:
return None
prefill_metadata = DualChunkFlashAttentionMetadata(
**prefill_metadata.asdict_zerocopy())
prefill_metadata.orig_seq_lens = (
None if self.orig_seq_lens is None else
self.orig_seq_lens[:self.num_prefills])
prefill_metadata.orig_seq_lens_tensor = (
None if self.orig_seq_lens_tensor is None else
self.orig_seq_lens_tensor[:self.num_prefills])
if self.original_max_position_embeddings > 0:
assert prefill_metadata.orig_seq_lens_tensor is not None
prefill_metadata.scaling_factor = (
0.1 * torch.log(prefill_metadata.orig_seq_lens_tensor /
self.original_max_position_embeddings) +
1.0).clip(min=1)
self._cached_prefill_metadata = prefill_metadata
return prefill_metadata
@property
def decode_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
decode_metadata = super().decode_metadata
if decode_metadata is None:
return None
decode_metadata = DualChunkFlashAttentionMetadata(
**decode_metadata.asdict_zerocopy())
decode_metadata.orig_seq_lens_tensor = (
None if self.orig_seq_lens_tensor is None else
self.orig_seq_lens_tensor[self.num_prefills:])
assert decode_metadata.orig_seq_lens_tensor is not None
assert decode_metadata.block_tables is not None
cache_seq_lens = decode_metadata.orig_seq_lens_tensor
chunk_len = self.chunk_size - self.local_size
chunk_num_curr = (cache_seq_lens - 1) // chunk_len
batch_size = decode_metadata.num_decode_tokens
if self.original_max_position_embeddings > 0:
decode_metadata.scaling_factor = (0.1 * torch.log(
cache_seq_lens / self.original_max_position_embeddings) +
1.0).clip(min=1)
seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len
max_seq_len_intra = seq_lens_intra.max().item()
decode_metadata.seq_lens_intra = seq_lens_intra
decode_metadata.max_seq_len_intra = max_seq_len_intra
block_tables_intra = torch.zeros(
batch_size,
(max_seq_len_intra - 1) // self.block_size + 1,
dtype=decode_metadata.block_tables.dtype,
device=decode_metadata.block_tables.device,
)
for i in range(batch_size):
st = chunk_num_curr[i] * chunk_len // self.block_size
ed = min(
st + (max_seq_len_intra - 1) // self.block_size + 1,
(cache_seq_lens[i] - 1) // self.block_size + 1,
)
block_tables_intra[i, :ed -
st] = decode_metadata.block_tables[i, st:ed]
decode_metadata.block_tables_intra = block_tables_intra
seq_lens_succ = (chunk_num_curr -
(chunk_num_curr - 1).clip(min=0)) * chunk_len
max_seq_len_succ = seq_lens_succ.max().item()
decode_metadata.seq_lens_succ = seq_lens_succ
decode_metadata.max_seq_len_succ = max_seq_len_succ
if max_seq_len_succ:
block_tables_succ = torch.zeros(
batch_size,
(max_seq_len_succ - 1) // self.block_size + 1,
dtype=decode_metadata.block_tables.dtype,
device=decode_metadata.block_tables.device,
)
for i in range(batch_size):
start = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len //
self.block_size)
end = min(
start + (max_seq_len_succ - 1) // self.block_size + 1,
(cache_seq_lens[i] - 1) // self.block_size + 1,
)
block_tables_succ[
i, :end - start] = decode_metadata.block_tables[i,
start:end]
decode_metadata.block_tables_succ = block_tables_succ
seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len
max_seq_len_inter = seq_lens_inter.max().item()
decode_metadata.seq_lens_inter = seq_lens_inter
decode_metadata.max_seq_len_inter = max_seq_len_inter
self._cached_decode_metadata = decode_metadata
return decode_metadata
class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
def prepare(self):
super().prepare()
self.orig_seq_lens: List[int] = []
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
super()._add_seq_group(inter_data, chunked_prefill_enabled,
prefix_cache_hit)
for prompt_len, seq_len in zip(inter_data.prompt_lens,
inter_data.seq_lens):
self.orig_seq_lens.append(max(prompt_len, seq_len))
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
attn_metadata = super().build(seq_lens, query_lens,
cuda_graph_pad_size, batch_size)
attn_metadata = DualChunkFlashAttentionMetadata(
**attn_metadata.asdict_zerocopy())
device = self.runner.device
attn_metadata.orig_seq_lens = self.orig_seq_lens
attn_metadata.orig_seq_lens_tensor = async_tensor_h2d(
self.orig_seq_lens, torch.int, device, self.runner.pin_memory)
attn_metadata.block_size = self.runner.block_size
dual_chunk_attn_config = getattr(self.runner.model_config.hf_config,
"dual_chunk_attention_config", {})
attn_metadata.original_max_position_embeddings = \
dual_chunk_attn_config.get("original_max_position_embeddings", 0)
attn_metadata.chunk_size = dual_chunk_attn_config.get(
"chunk_size", 8192)
attn_metadata.local_size = dual_chunk_attn_config.get(
"local_size", 1024)
return attn_metadata
class DualChunkFlashAttentionImpl(FlashAttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
layer_idx: int = -1,
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if sliding_window is not None:
# NOTE(woosuk): flash-attn's sliding window does not work with
# paged KV cache.
raise ValueError(
"Sliding window is not supported in FlashAttention.")
support_head_sizes = (
DualChunkFlashAttentionBackend.get_supported_head_sizes())
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
assert dual_chunk_attention_config is not None
self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192)
self.local_size = dual_chunk_attention_config.get("local_size", 1024)
self.original_max_position_embeddings = dual_chunk_attention_config.get(
"original_max_position_embeddings", 0)
self.sparse_attention_config = dual_chunk_attention_config.get(
"sparse_attention_config", None)
if not self.sparse_attention_config:
logger.warning_once("Sparse attention will not be enabled as "
"sparse attention config is not provided.")
self.sparse_attention_enabled = dual_chunk_attention_config.get(
"sparse_attention_enabled", self.sparse_attention_config
is not None)
self.sparse_attention_threshold = dual_chunk_attention_config.get(
"sparse_attention_threshold", 32768)
self.sparse_attention_last_q = dual_chunk_attention_config.get(
"sparse_attention_last_q", 64)
self.layer_idx = layer_idx
self.dual_chunk_attention_config = dual_chunk_attention_config
if self.sparse_attention_config:
self.sparse_attention_config = {
int(i): j
for i, j in self.sparse_attention_config[
self.layer_idx].items()
}
start_head = self.num_heads * get_tensor_model_parallel_rank()
end_head = start_head + self.num_heads
self.sparse_attention_config = [
self.sparse_attention_config[i]
for i in range(start_head, end_head)
]
if self.sparse_attention_enabled:
self.arange = torch.arange(self.sparse_attention_last_q,
device="cuda")
self.last_q_mask = (self.arange[None, None, :, None]
>= self.arange[None, None, None, :])
def forward( # type: ignore
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: DualChunkFlashAttentionMetadata,
) -> torch.Tensor:
"""Forward pass with DualChunkFlashAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
query_succ: shape = [num_tokens, num_heads * head_size]
query_inter: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
(
query,
query_succ,
query_inter,
query_succ_critical,
query_inter_critical,
) = torch.split(query, query.shape[-1] // 5, dim=-1)
assert (
query_succ is not None and query_inter is not None
), "query_succ and query_inter are required in Dual Chunk Attention."
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
query_succ = query_succ.view(-1, self.num_heads, self.head_size)
query_inter = query_inter.view(-1, self.num_heads, self.head_size)
query_succ_critical = query_succ_critical.view(-1, self.num_heads,
self.head_size)
query_inter_critical = query_inter_critical.view(
-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.original_max_position_embeddings > 0:
if prefill_meta := attn_metadata.prefill_metadata:
assert prefill_meta.scaling_factor is not None
assert prefill_meta.query_start_loc is not None
assert prefill_meta.orig_seq_lens is not None
current_start = 0
query_start_loc_cpu = prefill_meta.query_start_loc.cpu()
for i in range(len(prefill_meta.orig_seq_lens)):
current_end = (current_start +
(query_start_loc_cpu[i + 1] -
query_start_loc_cpu[i]).item())
key[current_start:current_end].mul_(
prefill_meta.scaling_factor[i])
current_start = current_end
assert current_end <= attn_metadata.num_prefill_tokens
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta.scaling_factor is not None
scaling_factor = decode_meta.scaling_factor
key[attn_metadata.num_prefill_tokens:].mul_(
scaling_factor.unsqueeze(-1).unsqueeze(-1))
if kv_cache is not None and kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
decode_query_succ = query_succ[num_prefill_tokens:]
decode_query_inter = query_inter[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
query_succ = query_succ[:num_prefill_tokens]
query_inter = query_inter[:num_prefill_tokens]
query_succ_critical = query_succ_critical[:num_prefill_tokens]
query_inter_critical = query_inter_critical[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache is None or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention, called during the profiling run.
out = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
assert prefill_meta.orig_seq_lens is not None
output[:num_prefill_tokens] = (
self._dual_chunk_flash_attn_prefill(
q=query,
q_succ=query_succ,
q_inter=query_inter,
q_succ_critical=query_succ_critical,
q_inter_critical=query_inter_critical,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
orig_seq_lens=prefill_meta.orig_seq_lens,
scaling_factor=prefill_meta.scaling_factor,
softmax_scale=self.scale,
causal=True,
window_size=(-1, -1),
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
chunk_size=self.chunk_size,
local_size=self.local_size,
))
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[num_prefill_tokens:] = (
self._dual_chunk_flash_attn_decoding(
decode_query.unsqueeze(1),
decode_query_succ.unsqueeze(1),
decode_query_inter.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
chunk_size=self.chunk_size,
local_size=self.local_size,
original_max_position_embeddings=self.
original_max_position_embeddings,
decode_meta=decode_meta,
).squeeze(1))
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)
def _dual_chunk_flash_attn_prefill(
self,
q,
q_succ,
q_inter,
q_succ_critical,
q_inter_critical,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
orig_seq_lens: List[int],
scaling_factor: torch.Tensor,
softmax_scale: float,
causal: Optional[bool] = True,
window_size: Tuple[int, int] = (-1, -1),
alibi_slopes: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
chunk_size: int = 8192,
local_size: int = 1024,
):
if alibi_slopes is not None:
raise ValueError(
"Dual Chunk Attention does not support alibi_slopes")
if not causal:
raise ValueError(
"Dual Chunk Attention does not support causal=False")
if window_size != (-1, -1):
raise ValueError(
"Dual Chunk Attention does not support window_size")
cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist()
cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist()
all_outputs = []
for i in range(0, len(cu_seqlens_q_cpu) - 1):
qs = cu_seqlens_q_cpu[i]
qe = cu_seqlens_q_cpu[i:i + 2][-1]
ks = cu_seqlens_k_cpu[i]
ke = cu_seqlens_k_cpu[i:i + 2][-1]
current_q = q[qs:qe]
current_q_succ = q_succ[qs:qe]
current_q_inter = q_inter[qs:qe]
current_q_succ_critical = q_succ_critical[qs:qe]
current_q_inter_critical = q_inter_critical[qs:qe]
if block_table is None:
current_k = k[ks:ke]
current_v = v[ks:ke]
current_block_table = None
current_orig_seq_len = orig_seq_lens[i]
else:
current_block_table = block_table[i]
current_orig_seq_len = orig_seq_lens[i]
current_k = k
current_v = v
sparse_attn_enabled = (self.sparse_attention_enabled
and current_orig_seq_len
> self.sparse_attention_threshold)
if current_q.shape[0] == 0:
continue
if current_k.shape[0] == 0:
all_outputs.append(
torch.zeros(
(current_q.shape[0], current_q.shape[1], v.shape[2]),
device=q.device,
dtype=q.dtype,
))
continue
current_output = torch.empty_like(current_q)
group_size = int(current_q.size(-2) / current_k.size(-2))
if sparse_attn_enabled:
num_device_q_heads = current_q.size(-2)
heads_vertical_size = torch.empty(size=(num_device_q_heads, ),
dtype=torch.int32)
heads_slash_size = torch.empty(size=(num_device_q_heads, ),
dtype=torch.int32)
for head_id in range(current_q.size(-2)):
(
ty,
vertical_size,
slash_size,
_,
) = self.sparse_attention_config[head_id]
assert ty == "vertical_and_slash", "only support slash mode"
if vertical_size == 30:
vertical_size += 100
heads_vertical_size[head_id] = vertical_size
heads_slash_size[head_id] = slash_size
current_output = self._dual_chunk_flash_attn_prefill_func(
current_q, # allheads
current_q_succ,
current_q_inter,
current_q_succ_critical,
current_q_inter_critical,
current_k,
current_v,
current_block_table,
softmax_scale,
chunk_size,
local_size,
scaling_factor[i].item(),
ke - ks,
sparse_attn_enabled=sparse_attn_enabled,
heads_vertical_size=heads_vertical_size,
heads_slash_size=heads_slash_size,
group_size=group_size)
else:
for head_id in range(current_q.size(-2)):
# (seq_len, num_heads, head_size)
current_q_head = current_q[:, head_id, :].unsqueeze(1)
current_q_succ_head = \
current_q_succ[:, head_id, :].unsqueeze(1)
current_q_inter_head = \
current_q_inter[:, head_id, :].unsqueeze(1)
current_q_succ_head_critical = \
current_q_succ_critical[:, head_id, :].unsqueeze(1)
current_q_inter_head_critical = \
current_q_inter_critical[:, head_id, :].unsqueeze(1)
if block_table is not None:
current_k_head = current_k[..., head_id //
group_size, :].unsqueeze(2)
current_v_head = current_v[..., head_id //
group_size, :].unsqueeze(2)
else:
current_k_head = current_k[:, head_id, :].unsqueeze(1)
current_v_head = current_v[:, head_id, :].unsqueeze(1)
current_out = self._dual_chunk_flash_attn_prefill_func(
current_q_head,
current_q_succ_head,
current_q_inter_head,
current_q_succ_head_critical,
current_q_inter_head_critical,
current_k_head,
current_v_head,
current_block_table,
softmax_scale,
chunk_size,
local_size,
scaling_factor[i].item(),
ke - ks,
sparse_attn_enabled=sparse_attn_enabled,
)
current_output[:, head_id:head_id + 1, :] = current_out
all_outputs.append(current_output)
return torch.cat(all_outputs, dim=0)
def _dual_chunk_flash_attn_prefill_func(
self,
q,
q_succ,
q_inter,
q_succ_critical,
q_inter_critical,
k,
v,
block_table,
softmax_scale: float,
chunk_size: int,
local_size: int,
scaling_factor: float,
k_length: int,
sparse_attn_enabled: Optional[bool] = True,
heads_vertical_size=None,
heads_slash_size=None,
group_size=None,
):
flash_results = []
chunk_len = chunk_size - local_size
if block_table is not None:
block_size = v.shape[1]
if chunk_len % block_size != 0:
raise ValueError("chunk_len must be divisible by block_size.")
else:
block_size = 1
if self.original_max_position_embeddings > 0:
softmax_scale = softmax_scale * scaling_factor
begin = k_length - q.shape[0]
while begin < k_length:
flash_per_chunk = []
prev_chunk_end_pos = (begin // chunk_len) * chunk_len
next_chunk_end_pos = prev_chunk_end_pos + chunk_len
end = min(next_chunk_end_pos, k_length)
qbegin = begin - (k_length - q.shape[0])
qend = end - (k_length - q.shape[0])
qk_chunks = []
q_states_intra = q[qbegin:qend]
# choose critical token
if block_table is not None:
block_tables_intra = _get_block(block_table, block_size,
prev_chunk_end_pos, end)
k_states_intra = k[block_tables_intra].view(
-1, *k.shape[-2:])[:(end - prev_chunk_end_pos)]
v_states_intra = v[block_tables_intra].view(
-1, *v.shape[-2:])[:(end - prev_chunk_end_pos)]
else:
block_tables_intra = None
k_states_intra = k[prev_chunk_end_pos:end]
v_states_intra = v[prev_chunk_end_pos:end]
if sparse_attn_enabled:
last_q_size = min(qend - qbegin, self.sparse_attention_last_q)
_, num_device_k_heads, head_dim = k_states_intra.shape
k_states_intra = (k_states_intra.unsqueeze(2).repeat(
1, 1, group_size,
1).reshape(-1, num_device_k_heads * group_size, head_dim))
v_states_intra = (v_states_intra.unsqueeze(2).repeat(
1, 1, group_size,
1).reshape(-1, num_device_k_heads * group_size, head_dim))
qk_chunks.append(
(q_states_intra.transpose(0, 1)[:, -last_q_size:] *
softmax_scale) @ k_states_intra.permute(1, 2, 0))
if prev_chunk_end_pos - chunk_len >= 0:
q_states_succ = q_succ[qbegin:qend]
q_states_succ_critical = q_succ_critical[qbegin:qend]
if block_table is not None:
block_tables_succ = _get_block(
block_table, block_size,
prev_chunk_end_pos - chunk_len, prev_chunk_end_pos)
k_states_succ = k[block_tables_succ].view(
-1, *k.shape[-2:])[:chunk_len]
v_states_succ = v[block_tables_succ].view(
-1, *v.shape[-2:])[:chunk_len]
else:
k_states_succ = k[prev_chunk_end_pos -
chunk_len:prev_chunk_end_pos]
v_states_succ = v[prev_chunk_end_pos -
chunk_len:prev_chunk_end_pos]
if sparse_attn_enabled:
k_states_succ = (k_states_succ.unsqueeze(2).repeat(
1, 1, group_size,
1).reshape(-1, num_device_k_heads * group_size,
head_dim))
v_states_succ = (v_states_succ.unsqueeze(2).repeat(
1, 1, group_size,
1).reshape(-1, num_device_k_heads * group_size,
head_dim))
qk_chunks.append((q_states_succ_critical.transpose(
0, 1)[:, -last_q_size:] * softmax_scale)
@ k_states_succ.permute(1, 2, 0))
if prev_chunk_end_pos - chunk_len * 2 >= 0:
q_states_inter = q_inter[qbegin:qend]
q_states_inter_critical = q_inter_critical[qbegin:qend]
if block_table is not None:
block_tables_inter = _get_block(
block_table, block_size, 0,
prev_chunk_end_pos - chunk_len)
k_states_inter = k[block_tables_inter].view(
-1, *k.shape[-2:])[:(prev_chunk_end_pos - chunk_len)]
v_states_inter = v[block_tables_inter].view(
-1, *v.shape[-2:])[:(prev_chunk_end_pos - chunk_len)]
else:
k_states_inter = k[:prev_chunk_end_pos - chunk_len]
v_states_inter = v[:prev_chunk_end_pos - chunk_len]
if sparse_attn_enabled:
k_states_inter = (k_states_inter.unsqueeze(2).repeat(
1, 1, group_size,
1).reshape(-1, num_device_k_heads * group_size,
head_dim))
v_states_inter = (v_states_inter.unsqueeze(2).repeat(
1, 1, group_size,
1).reshape(-1, num_device_k_heads * group_size,
head_dim))
qk_chunks.append((q_states_inter_critical.transpose(
0, 1)[:, -last_q_size:] * softmax_scale)
@ k_states_inter.permute(1, 2, 0))
if sparse_attn_enabled:
reversed_qk = qk_chunks[::-1]
qk = torch.cat(reversed_qk, dim=-1)
qk[:, :, -last_q_size:] = torch.where(
self.last_q_mask[..., -last_q_size:,
-last_q_size:].to(qk.device),
qk[:, :, -last_q_size:], -torch.inf)
qk = F.softmax(qk, dim=-1, dtype=torch.float32)
vertical = qk.sum(-2, keepdim=True)
vertical[..., :30] = torch.inf
# Avoid sorting by using the min/max ints to fill the indexer
# buffers.
int32_max = torch.iinfo(torch.int32).max
int32_min = torch.iinfo(torch.int32).min
n_heads = qk.size()[0]
max_slash_topk = torch.max(heads_slash_size).item()
max_vertical_topk = torch.max(heads_vertical_size).item()
# store each head's slash topk, vertical topk
vertical = vertical.reshape((n_heads, -1))
# prevent out of range when prompt size < max_vertical_topk
max_vertical_topk = min(vertical.shape[-1], max_vertical_topk)
vertical_topk_buffer = torch.topk(vertical, max_vertical_topk,
-1).indices
slash_topk_buffer = torch.empty(size=(n_heads, max_slash_topk),
dtype=torch.int64,
device=qk.device)
for head_i in range(n_heads):
# (nqheads=1, lastq, k_len)
head_score = qk[head_i:head_i + 1, :, :]
slash_scores = _sum_all_diagonal_matrix(head_score)
if head_score.size(1) != 1:
# drop right up corner
slash_scores = slash_scores[..., :-last_q_size + 1]
slash_scores[..., -100:] = torch.inf
head_slash_size = heads_slash_size[head_i]
head_slash_size = min(head_slash_size, vertical.size(-1))
slash_topk = torch.topk(slash_scores, head_slash_size,
-1).indices
#(nheads, max_topk)
slash_topk_buffer[head_i, :head_slash_size] = slash_topk
# reset heads topk
heads_slash_size[head_i] = head_slash_size
heads_vertical_size[head_i] = min(
heads_vertical_size[head_i], max_vertical_topk)
# store
vertical_buffer = torch.full((n_heads, max_vertical_topk),
int32_max,
dtype=torch.int64,
device=q.device)
slash_buffer = torch.full((n_heads, max_slash_topk),
int32_min,
dtype=torch.int64,
device=q.device)
succ_vertical_buffer = torch.full((n_heads, max_vertical_topk),
int32_max,
dtype=torch.int64,
device=q.device)
succ_slash_buffer = torch.full((n_heads, max_slash_topk),
int32_min,
dtype=torch.int64,
device=q.device)
inter_vertical_buffer = torch.full(
(n_heads, max_vertical_topk),
int32_max,
dtype=torch.int64,
device=q.device)
inter_slash_buffer = torch.full((n_heads, max_slash_topk),
int32_min,
dtype=torch.int64,
device=q.device)
vertical_size_buffer = torch.empty(size=(n_heads, ),
dtype=torch.int32,
device=q.device)
slash_sizes_buffer = torch.empty(size=(n_heads, ),
dtype=torch.int32,
device=q.device)
succ_vertical_size_buffer = torch.empty(size=(n_heads, ),
dtype=torch.int32,
device=q.device)
succ_slash_sizes_buffer = torch.empty(size=(n_heads, ),
dtype=torch.int32,
device=q.device)
inter_vertical_size_buffer = torch.empty(size=(n_heads, ),
dtype=torch.int32,
device=q.device)
inter_slash_sizes_buffer = torch.empty(size=(n_heads, ),
dtype=torch.int32,
device=q.device)
for head_i in range(n_heads):
vertical_topk = vertical_topk_buffer[
head_i, :heads_vertical_size[head_i]]
# intra
intra_vertical_indices = vertical_topk[
vertical_topk >=
prev_chunk_end_pos] - prev_chunk_end_pos
if intra_vertical_indices.nelement() == 0:
intra_vertical_indices = torch.cat([
intra_vertical_indices,
torch.arange(0,
k_states_intra.size(0),
max(1,
k_states_intra.size(0) / 5),
dtype=torch.int32,
device=intra_vertical_indices.device)
])
slash_topk = slash_topk_buffer[
head_i, :heads_slash_size[head_i]]
intra_slash_indices = (
(qk.size(-1) - 1) -
slash_topk[slash_topk >= prev_chunk_end_pos])
# fill buffer
v_count = intra_vertical_indices.nelement()
s_count = intra_slash_indices.nelement()
vertical_size_buffer[head_i] = v_count
slash_sizes_buffer[head_i] = s_count
vertical_buffer[head_i, :v_count].copy_(
intra_vertical_indices)
slash_buffer[head_i, :s_count].copy_(intra_slash_indices)
# succ
if prev_chunk_end_pos - chunk_len >= 0:
succ_vertical_indices = vertical_topk[
(vertical_topk < prev_chunk_end_pos)
& (vertical_topk >= prev_chunk_end_pos -
chunk_len)] - (prev_chunk_end_pos - chunk_len)
# TODO: support no vertical
if succ_vertical_indices.nelement() == 0:
succ_vertical_indices = torch.cat([
succ_vertical_indices,
torch.arange(
0,
k_states_succ.size(0),
max(1,
k_states_succ.size(0) / 5),
dtype=torch.int32,
device=intra_vertical_indices.device)
])
succ_slash_indices = (
(prev_chunk_end_pos + (qend - qbegin) - 1) -
slash_topk[((slash_topk >=
(prev_chunk_end_pos - chunk_len)) &
(slash_topk < (prev_chunk_end_pos +
(qend - qbegin))))])
if succ_slash_indices.nelement() == 0:
succ_slash_indices = torch.cat([
succ_slash_indices,
torch.arange(
0,
k_states_succ.size(0),
max(1,
k_states_succ.size(0) / 5),
dtype=torch.int32,
device=intra_vertical_indices.device)
])
# fill buffer
v_count = succ_vertical_indices.nelement()
s_count = succ_slash_indices.nelement()
succ_vertical_size_buffer[head_i] = v_count
succ_slash_sizes_buffer[head_i] = s_count
succ_vertical_buffer[head_i, :v_count].copy_(
succ_vertical_indices)
succ_slash_buffer[head_i, :s_count].copy_(
succ_slash_indices)
if prev_chunk_end_pos - 2 * chunk_len >= 0:
inter_vertical_indices = vertical_topk[
vertical_topk < prev_chunk_end_pos - chunk_len]
if inter_vertical_indices.nelement() == 0:
inter_vertical_indices = torch.cat([
inter_vertical_indices,
torch.arange(
0,
k_states_inter.size(0),
max(1,
k_states_inter.size(0) / 5),
dtype=torch.int32,
device=intra_vertical_indices.device)
])
inter_slash_indices = (
(prev_chunk_end_pos - chunk_len +
(qend - qbegin) - 1) -
slash_topk[slash_topk < (prev_chunk_end_pos -
chunk_len +
(qend - qbegin))])
if inter_slash_indices.nelement() == 0:
inter_slash_indices = torch.cat([
inter_slash_indices,
torch.arange(
0,
k_states_inter.size(0),
max(1,
k_states_inter.size(0) / 5),
dtype=torch.int32,
device=intra_vertical_indices.device)
])
# fill buffer
v_count = inter_vertical_indices.nelement()
s_count = inter_slash_indices.nelement()
inter_vertical_size_buffer[head_i] = v_count
inter_slash_sizes_buffer[head_i] = s_count
inter_vertical_buffer[head_i, :v_count].copy_(
inter_vertical_indices)
inter_slash_buffer[head_i, :s_count].copy_(
inter_slash_indices)
else:
intra_vertical_indices, intra_slash_indices = None, None
succ_vertical_indices, succ_slash_indices = None, None
inter_vertical_indices, inter_slash_indices = None, None
if sparse_attn_enabled:
flash_result = self._do_flash_attn(
q_states_intra,
k_states_intra,
v_states_intra,
softmax_scale=softmax_scale,
causal=True,
block_table=block_table,
stage="intra",
vertical_indices=vertical_buffer,
slash_indices=slash_buffer,
vertical_indices_count=vertical_size_buffer,
slash_indices_count=slash_sizes_buffer,
mergehead_softmax_scale=softmax_scale,
sparse_attn_enabled=sparse_attn_enabled)
else:
flash_result = self._do_flash_attn(
q_states_intra,
k_states_intra,
v_states_intra,
softmax_scale=softmax_scale,
causal=True,
block_table=block_table,
stage="intra",
vertical_indices=intra_vertical_indices,
slash_indices=intra_slash_indices,
sparse_attn_enabled=sparse_attn_enabled)
flash_per_chunk.append(flash_result)
if prev_chunk_end_pos - chunk_len >= 0:
if sparse_attn_enabled:
flash_result = self._do_flash_attn(
q_states_succ,
k_states_succ,
v_states_succ,
softmax_scale=softmax_scale,
causal=False,
block_table=block_table,
stage="succ",
vertical_indices=succ_vertical_buffer,
slash_indices=succ_slash_buffer,
vertical_indices_count=succ_vertical_size_buffer,
slash_indices_count=succ_slash_sizes_buffer,
mergehead_softmax_scale=softmax_scale,
sparse_attn_enabled=sparse_attn_enabled)
else:
flash_result = self._do_flash_attn(
q_states_succ,
k_states_succ,
v_states_succ,
softmax_scale=softmax_scale,
causal=False,
block_table=block_table,
stage="succ",
vertical_indices=succ_vertical_indices,
slash_indices=succ_slash_indices,
sparse_attn_enabled=sparse_attn_enabled)
flash_per_chunk.append(flash_result)
if prev_chunk_end_pos - chunk_len * 2 >= 0:
if sparse_attn_enabled:
flash_result = self._do_flash_attn(
q_states_inter,
k_states_inter,
v_states_inter,
softmax_scale=softmax_scale,
causal=False,
block_table=block_table,
stage="inter",
vertical_indices=inter_vertical_buffer,
slash_indices=inter_slash_buffer,
vertical_indices_count=inter_vertical_size_buffer,
slash_indices_count=inter_slash_sizes_buffer,
mergehead_softmax_scale=softmax_scale,
sparse_attn_enabled=sparse_attn_enabled)
else:
flash_result = self._do_flash_attn(
q_states_inter,
k_states_inter,
v_states_inter,
softmax_scale=softmax_scale,
causal=False,
block_table=block_table,
stage="inter",
vertical_indices=inter_vertical_indices,
slash_indices=inter_slash_indices,
sparse_attn_enabled=sparse_attn_enabled)
flash_per_chunk.append(flash_result)
flash_results.append(flash_per_chunk)
begin = end
attn_output = self._merge_attn_outputs(flash_results)
del flash_results
return attn_output
def _do_flash_attn(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
softmax_scale: float,
causal: bool = True,
block_table: torch.Tensor = None,
max_seqlen_k: Optional[int] = None,
stage: str = "intra",
vertical_indices: Optional[torch.Tensor] = None,
slash_indices: Optional[torch.Tensor] = None,
vertical_indices_count: Optional[torch.Tensor] = None,
slash_indices_count: Optional[torch.Tensor] = None,
mergehead_softmax_scale: Optional[float] = None,
sparse_attn_enabled: Optional[bool] = False,
):
if max_seqlen_k is None:
max_seqlen_k = key_states.shape[0]
q_len = query_states.shape[0]
q_heads = query_states.shape[1]
h_dim = query_states.shape[-1]
if sparse_attn_enabled:
assert slash_indices is not None
if stage == "intra":
assert causal
else:
assert not causal
query_states = query_states.unsqueeze(0).transpose(1, 2)
key_states = key_states.unsqueeze(0).transpose(1, 2)
value_states = value_states.unsqueeze(0).transpose(1, 2)
q = query_states
k = key_states
v = value_states
if (vertical_indices_count is not None and \
slash_indices_count is not None):
assert mergehead_softmax_scale is not None
res, s_lse = _vertical_slash_sparse_attention(
q,
k,
v,
vertical_indices,
slash_indices,
mergehead_softmax_scale,
causal=causal,
stage=stage,
vertical_indices_count=vertical_indices_count,
slash_indices_count=slash_indices_count)
res = res.view(q_heads, q_len,
h_dim).transpose(0, 1) # (qlen,nhead,h_dim)
s_lse = s_lse.view(
q_heads, q_len,
1).squeeze(-1).unsqueeze(0).float() # (1, nhead,qlen)
else:
res, s_lse = _vertical_slash_sparse_attention(q,
k,
v,
vertical_indices,
slash_indices,
softmax_scale,
causal=causal,
stage=stage)
res = res.view(q_len, q_heads, h_dim)
s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float()
return res, s_lse
output, softmax_lse = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
softmax_scale=softmax_scale,
cu_seqlens_q=torch.tensor([0, query_states.shape[0]],
dtype=torch.int32,
device=query_states.device),
max_seqlen_q=query_states.shape[0],
cu_seqlens_k=torch.tensor([0, max_seqlen_k],
dtype=torch.int32,
device=query_states.device),
max_seqlen_k=max_seqlen_k,
causal=causal,
block_table=block_table.unsqueeze(0),
return_softmax_lse=True,
)
softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0,
2).float()
return output, softmax_lse
def _merge_attn_outputs(
self,
flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]],
return_lse: Optional[bool] = False,
) -> torch.Tensor:
attn_outputs_all = []
logits_all = []
for flash_per_chunk in flash_results:
if len(flash_per_chunk) == 1:
attn_outputs_all.append(flash_per_chunk[0][0])
if return_lse:
logits_all.append(flash_per_chunk[0][1])
continue
attn_outputs = torch.stack([
flash_attn_output[0] for flash_attn_output in flash_per_chunk
])
logits = torch.stack([
flash_attn_output[1] for flash_attn_output in flash_per_chunk
])
logits = logits.to(torch.float32)
if return_lse:
max_val = torch.max(logits, dim=0).values
diff = torch.abs(logits[0] - logits[1])
log_sum_exp = max_val + torch.log1p(torch.exp(-diff))
logits_all.append(log_sum_exp)
max_logits = torch.max(logits, dim=0).values
stable_logits = logits - max_logits.unsqueeze(0)
lse_s = torch.exp(stable_logits).detach()
lse_sum = torch.sum(lse_s, dim=0)
lse_s /= lse_sum
attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1)
attn_outputs_all.append(attn_outputs.sum(dim=0))
if return_lse:
return (torch.cat(attn_outputs_all,
dim=0), torch.cat(logits_all, dim=-1))
else:
return torch.cat(attn_outputs_all, dim=0)
def _dual_chunk_flash_attn_decoding(
self,
query: torch.Tensor,
query_succ: torch.Tensor,
query_inter: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
softmax_scale: float,
causal: bool,
alibi_slopes: Optional[torch.Tensor],
chunk_size: int,
local_size: int,
original_max_position_embeddings: int,
decode_meta: DualChunkFlashAttentionMetadata,
):
if not causal:
raise ValueError(
"Dual Chunk Attention does not support causal=False")
block_size = value_cache.shape[1]
chunk_len = chunk_size - local_size
if chunk_len % block_size != 0:
raise ValueError("chunk_len must be divisible by block_size.")
if original_max_position_embeddings > 0:
assert decode_meta.scaling_factor is not None
scaling_factor = decode_meta.scaling_factor
query = (query * scaling_factor.view(-1, 1, 1, 1)).to(
query.dtype
) # possible for numerical issue, need to fused in the kernel
query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to(
query.dtype)
query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to(
query.dtype)
outputs_list = []
softmax_lses_list = []
# intra-attention
intra_output, intra_softmax_lse = (
self._dual_chunk_flash_attn_decoding_with_exp_sums(
query,
key_cache,
value_cache,
decode_meta.block_tables_intra,
decode_meta.seq_lens_intra,
softmax_scale,
alibi_slopes,
causal=False,
))
outputs_list.append(intra_output)
softmax_lses_list.append(intra_softmax_lse)
# succ-attention
if decode_meta.max_seq_len_succ:
succ_output, succ_softmax_lse = (
self._dual_chunk_flash_attn_decoding_with_exp_sums(
query_succ,
key_cache,
value_cache,
decode_meta.block_tables_succ,
decode_meta.seq_lens_succ,
softmax_scale,
alibi_slopes,
causal=False,
))
outputs_list.append(succ_output)
softmax_lses_list.append(succ_softmax_lse)
# inter-attention
if decode_meta.max_seq_len_inter:
inter_output, inter_softmax_lse = (
self._dual_chunk_flash_attn_decoding_with_exp_sums(
query_inter,
key_cache,
value_cache,
block_table[:, :decode_meta.max_seq_len_inter],
decode_meta.seq_lens_inter,
softmax_scale,
alibi_slopes,
causal=False,
))
outputs_list.append(inter_output)
softmax_lses_list.append(inter_softmax_lse)
outputs = torch.stack(outputs_list, dim=0)
del outputs_list
softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32)
del softmax_lses_list
max_logits = torch.max(softmax_lses, dim=0).values
stable_logits = softmax_lses - max_logits.unsqueeze(0)
lse_s = torch.exp(stable_logits).detach()
lse_sum = torch.sum(lse_s, dim=0)
lse_s /= lse_sum
outputs *= lse_s.unsqueeze(-1).transpose(2, 3)
return outputs.sum(0)
def _dual_chunk_flash_attn_decoding_with_exp_sums(
self,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
softmax_scale: float,
alibi_slopes: Optional[torch.Tensor],
causal: bool,
):
out, softmax_lse = flash_attn_with_kvcache(
q=query,
k_cache=key_cache,
v_cache=value_cache,
block_table=block_table,
cache_seqlens=cache_seqlens,
softmax_scale=softmax_scale,
alibi_slopes=alibi_slopes,
causal=causal,
return_softmax_lse=True,
)
mask = (cache_seqlens == 0)
out[mask] = 0
softmax_lse[mask] = -float("inf")
return out, softmax_lse
def _vertical_slash_sparse_attention(
query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]
key: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
value: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
softmax_scale: float,
causal: bool = True,
stage: str = "intra",
block_size_M: int = 64,
block_size_N: int = 64,
vertical_indices_count: torch.Tensor = None, # [N_HEADS,]
slash_indices_count: torch.Tensor = None,
):
if stage == "intra":
assert causal
else:
assert not causal
batch_size, num_heads, context_size, head_dim = query.shape
_, _, kv_seq_len, _ = key.shape
if head_dim not in [16, 32, 64, 128, 256, 512]:
target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim
query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0])
key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0])
value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0])
v_idx = v_idx.to(torch.int32).reshape(
(batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0]
s_idx = s_idx.to(torch.int32).reshape(
(batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0]
q_seqlens = torch.tensor([context_size],
dtype=torch.int32,
device=query.device)
kv_seqlens = torch.tensor([kv_seq_len],
dtype=torch.int32,
device=query.device)
if vertical_indices_count is not None and slash_indices_count is not None:
(
block_count,
block_offset,
column_count,
column_index,
) = ops.convert_vertical_slash_indexes_mergehead(
q_seqlens, kv_seqlens, v_idx, s_idx, vertical_indices_count,
slash_indices_count, context_size, block_size_M, block_size_N,
causal)
else:
(
block_count,
block_offset,
column_count,
column_index,
) = ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, v_idx,
s_idx, context_size,
block_size_M, block_size_N,
causal)
q = query.transpose(1, 2).contiguous()
k = key.transpose(1, 2).contiguous()
v = value.transpose(1, 2).contiguous()
out, lse = sparse_attn_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
causal=causal,
softmax_scale=softmax_scale,
return_softmax_lse=True,
)
out = out.transpose(1, 2).contiguous()
softmax_lse = lse.reshape(*lse.shape, 1)
return (out[..., :context_size, :head_dim],
softmax_lse[..., :context_size, :])
def _sum_all_diagonal_matrix(mat: torch.tensor):
h, n, m = mat.shape
# Zero matrix used for padding
zero_mat = torch.zeros((h, n, n), device=mat.device)
# pads the matrix on left and right
mat_padded = torch.cat((zero_mat, mat, zero_mat), -1)
# Change the strides
mat_strided = mat_padded.as_strided((1, n, n + m),
(n * (2 * n + m), 2 * n + m + 1, 1))
# Sums the resulting matrix's columns
sum_diags = torch.sum(mat_strided, 1)
return sum_diags[:, 1:] # drop left bottom corner
def _get_block(block_table: torch.Tensor, block_size: int, begin: int,
end: int):
begin_block = begin // block_size
end_block = (end - 1) // block_size + 1
return block_table[begin_block:end_block]
......@@ -929,6 +929,23 @@ class ModelConfig:
"Number of experts in the model must be greater than 0 "
"when expert parallelism is enabled.")
def verify_dual_chunk_attention_config(
self,
load_config: "LoadConfig",
) -> None:
if hasattr(self.hf_config, "dual_chunk_attention_config"):
# Try loading the sparse attention config
from vllm.model_executor.model_loader.weight_utils import (
get_sparse_attention_config)
sparse_attn_config = get_sparse_attention_config(self, load_config)
if sparse_attn_config:
self.hf_config.dual_chunk_attention_config[
"sparse_attention_config"] = sparse_attn_config
if "sparse_attention_enabled" not in \
self.hf_config.dual_chunk_attention_config:
self.hf_config.dual_chunk_attention_config[
"sparse_attention_enabled"] = True
def verify_async_output_proc(self, parallel_config, speculative_config,
device_config) -> None:
if not self.use_async_output_proc:
......@@ -4187,6 +4204,8 @@ class VllmConfig:
self.speculative_config,
self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config)
self.model_config.verify_dual_chunk_attention_config(
self.load_config)
if self.cache_config is not None:
self.cache_config.verify_with_parallel_config(self.parallel_config)
......
......@@ -37,8 +37,8 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build,
is_in_ray_actor)
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, is_in_doc_build, is_in_ray_actor)
# yapf: enable
......@@ -983,6 +983,17 @@ class EngineArgs:
assert self.enable_chunked_prefill is not None
if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
assert self.enforce_eager, (
"Cuda graph is not supported with DualChunkFlashAttention. "
"To run the model in eager mode, set 'enforce_eager=True' "
"or use '--enforce-eager' in the CLI.")
assert current_platform.is_cuda(), (
"DualChunkFlashAttention is only supported on CUDA platform.")
assert not use_v1, (
"DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
......
......@@ -1486,6 +1486,184 @@ class MRotaryEmbedding(RotaryEmbedding):
return updates
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
chunk_size: int,
local_size: int,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.chunk_size = chunk_size
self.local_size = local_size
self.dtype = dtype
self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
(q_cache, qc_cache, k_cache, qc_no_clamp_cache,
q_inter_cache) = self._compute_cos_sin_cache()
self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
self.register_buffer("cos_sin_qc_no_clamp_cache",
qc_no_clamp_cache,
persistent=False)
self.register_buffer("cos_sin_q_inter_cache",
q_inter_cache,
persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
chunk_len = self.chunk_size - self.local_size
q_t = torch.arange(chunk_len, dtype=torch.float)
qc_t = (torch.arange(chunk_len, dtype=torch.float) +
chunk_len).clamp(max=self.chunk_size)
k_t = torch.arange(self.max_position_embeddings,
dtype=torch.float) % chunk_len
# count from chunk_len, no clamp(self.chunk_size) restriction
qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
# count from self.chunk_size for q_inter's rope
q_inter_t = torch.arange(chunk_len,
dtype=torch.float) + self.chunk_size
q_freqs = torch.outer(q_t, inv_freq)
qc_freqs = torch.outer(qc_t, inv_freq)
k_freqs = torch.outer(k_t, inv_freq)
qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
q_inter_freqs = torch.outer(q_inter_t, inv_freq)
q_cos = q_freqs.cos()
q_sin = q_freqs.sin()
qc_cos = qc_freqs.cos()
qc_sin = qc_freqs.sin()
k_cos = k_freqs.cos()
k_sin = k_freqs.sin()
qc_no_clamp_cos = qc_no_clamp_freqs.cos()
qc_no_clamp_sin = qc_no_clamp_freqs.sin()
q_inter_cos = q_inter_freqs.cos()
q_inter_sin = q_inter_freqs.sin()
q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin),
dim=-1).to(dtype=self.dtype,
device=self.device)
q_inter_cache = torch.cat((q_inter_cos, q_inter_sin),
dim=-1).to(dtype=self.dtype,
device=self.device)
return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
else:
query_pass = None
key_pass = None
positions_with_offsets = (torch.add(positions, offsets)
if offsets is not None else positions)
key = self._apply_rotary_embedding(
self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass)
chunk_len = self.chunk_size - self.local_size
query = self._apply_rotary_embedding(
self.cos_sin_q_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_succ = self._apply_rotary_embedding(
self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_inter = self._apply_rotary_embedding(
self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
query_rot, query_pass)
query_succ_critical = self._apply_rotary_embedding(
self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_inter_critical = self._apply_rotary_embedding(
self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
# merge query into one tensor to simplify the interfaces
query = torch.cat((
query,
query_succ,
query_inter,
query_succ_critical,
query_inter_critical,
),
dim=-1)
return query, key
def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
if self.rotary_dim < self.head_size:
hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
else:
hidden = hidden_rot
return hidden.flatten(-2).squeeze(0)
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
return s
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
......@@ -1498,6 +1676,7 @@ def get_rope(
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
......@@ -1510,14 +1689,35 @@ def get_rope(
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if dual_chunk_attention_config is not None:
dual_chunk_attention_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in dual_chunk_attention_config.items()
if k != "sparse_attention_config"
}
dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
else:
dual_chunk_attention_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args, dtype)
rope_scaling_args, dual_chunk_attention_args, dtype)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if not rope_scaling:
if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
for k, v in dual_chunk_attention_config.items()
if k in ("chunk_size", "local_size")
}
rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype,
**extra_kwargs)
elif not rope_scaling:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
else:
......
......@@ -217,6 +217,39 @@ def get_quant_config(model_config: ModelConfig,
return quant_cls.from_config(config)
def get_sparse_attention_config(
model_config: ModelConfig,
load_config: LoadConfig,
sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
model_name_or_path = model_config.model
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, load_config.download_dir):
hf_folder = snapshot_download(
model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=load_config.download_dir,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
tqdm_class=DisabledTqdm,
)
else:
hf_folder = model_name_or_path
config_file = os.path.join(hf_folder, sparse_attention_config_filename)
if not os.path.exists(config_file):
return {}
# Load the sparse attention config.
with open(config_file) as f:
config = json.load(f)
logger.info("Loaded sparse attention config from %s", config_file)
return config
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: Optional[str],
......
......@@ -23,7 +23,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from typing import Any, Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
is_pp_missing_parameter,
extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -99,7 +99,8 @@ class Qwen2MLP(nn.Module):
class Qwen2Attention(nn.Module):
def __init__(self,
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
......@@ -109,7 +110,9 @@ class Qwen2Attention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None:
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str,
Any]] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
......@@ -131,6 +134,7 @@ class Qwen2Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(
hidden_size,
......@@ -155,15 +159,21 @@ class Qwen2Attention(nn.Module):
max_position=max_position,
base=self.rope_theta,
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(self.num_heads,
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
attn_type=attn_type,
prefix=f"{prefix}.attn",
attn_type=attn_type)
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {})
def forward(
self,
......@@ -192,6 +202,9 @@ class Qwen2DecoderLayer(nn.Module):
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
......@@ -213,6 +226,7 @@ class Qwen2DecoderLayer(nn.Module):
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
......
......@@ -175,6 +175,7 @@ class Qwen2MoeAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -198,6 +199,7 @@ class Qwen2MoeAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(
hidden_size,
......@@ -221,14 +223,20 @@ class Qwen2MoeAttention(nn.Module):
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(self.num_heads,
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
prefix=f"{prefix}.attn",
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {})
def forward(
self,
......@@ -256,6 +264,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = Qwen2MoeAttention(
......@@ -268,6 +279,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
dual_chunk_attention_config=dual_chunk_attention_config,
)
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
......
......@@ -222,6 +222,10 @@ class CudaPlatformBase(Platform):
elif selected_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
return "vllm.attention.backends.xformers.XFormersBackend"
elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN:
pass
elif selected_backend:
......
......@@ -51,6 +51,7 @@ class _Backend(enum.Enum):
PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto()
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
DUAL_CHUNK_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
......
......@@ -153,6 +153,7 @@ STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
GB_bytes = 1_000_000_000
......
......@@ -204,6 +204,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore
self.prompt_lens[0] = 0 # type: ignore
self.query_lens[0] = 0 # type: ignore
self.context_lens[0] = 0 # type: ignore
self.curr_sliding_window_blocks[0] = 0 # type: ignore
......@@ -236,6 +237,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens: Optional[List[int]] = None,
# This is used in the dual-chunk flash attention backend.
prompt_lens: Optional[List[int]] = None,
# The query length.
query_lens: Optional[List[int]] = None,
# The number of tokens that are already computed.
......@@ -316,6 +319,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)):
self.orig_seq_lens[seq_id] = 0
if prompt_lens:
self.prompt_lens = prompt_lens
else:
for seq_id in range(len(self.seq_ids)):
self.prompt_lens[seq_id] = 0
if query_lens:
self.query_lens = query_lens
else:
......@@ -370,6 +379,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
self.prompt_lens = prompt_lens or []
self.query_lens = query_lens or []
self.context_lens = context_lens or []
self.curr_sliding_window_blocks = \
......@@ -403,6 +413,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs
self.prompt_lens = [0] * self.n_seqs
self.query_lens = [0] * self.n_seqs
self.context_lens = [0] * self.n_seqs
self.curr_sliding_window_blocks = [0] * self.n_seqs
......@@ -552,6 +563,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len()
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment