Unverified commit 4f33ece4 authored by Charlene Yang, committed by GitHub

Add KV cache for paged/non-paged attention (#1355)



* add paged attention; test_kv_cache_accuray and test_paged_attn pass
* remove unnecessary change from last commit
* test_fused_attn pass
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* remove unnecessary import in test_numerics
* add license for test
* fix lint
* add to L0 test
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* update license for test_paged_attn
* update kv_cache_manager license
* fix build issue from previous merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* WIP: minor fix/preparation for inference/cuda graph
* WIP: non-paged
* WIP: non-paged, bshd/sbhd
* WIP: non-paged, thd, no CG
* WIP: non-paged, thd, CG
* WIP: non-paged, CG
* WIP: non-paged, using paged kernel
* WIP: restructure kernels
* WIP: paged, CG
* WIP: padding + BRCM
* WIP: restructure IP, clean up
* WIP: fix non-CG, fused
* WIP: fix last commit
* WIP: unfused, non-CG
* WIP: flash-attn, non-CG
* WIP: flash_attn_with_kvcache
* commit two files missed by bcef6b34
* WIP: thd_bshd_bshd
* WIP: fix last commit
* WIP: fix 1c31b68d
* WIP: add bshd_2sbhd, sbhd_2bshd
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* WIP: some cleanup
* WIP: all qkv_format combinations and merge CM files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* WIP: some lint fixes
* WIP: add docstring for IP
* fix sequences_pre
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* WIP: minor fixes for multi-layer
* WIP: initial multi-layer test
* WIP: minor clean up
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* WIP: clean up
* WIP: switch to flash_attn_varlen_func
* WIP: fix unfused for separate q/kv format
* WIP: fix fused for separate q/kv formats
* WIP: flash attn + TELayer + 2 layers
* WIP: unfused + TL + 2layers
* WIP: all modules/backend
* WIP: minor cleanup
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* WIP: FlashAttention on Hopper with 2.7.3
* WIP: FlashAttention + v3 from 39e7179
* WIP: FlashAttention + v3 + FP8 + WIP
* WIP: add backend support table
* WIP: clean up
* WIP: separate use_flash_attention_2 and _3
* WIP: tweaks to paged attn script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* WIP: enable/disable certain cases for fused attn
* WIP: small fixes for lint and cg
* WIP: minor fixes for attn/infer
* WIP: fix CP
* WIP: readd page info to FADescriptor_v1
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor tweak to test_numerics.py
* fix 9.5/9.7 sq/skv + mask logic
* clean up
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor fix for FA3
* more minor fixes for FA3
* test page_size=1 for FA3
* fix t3hd/th3d strides
* fix ckpt recompute and fa3 k_scale
* raise dynamo recompile limit for test
* remove thunder test from L0
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix FA selection logic
* fix FA3 q_descale shape
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* remove page_table from IP.step() returns
* fix FP8 FlashAttn DPA fp8_dpa tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix CP
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor tweaks
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* update FA3 note and L3 test
* fix lint
* remove redundant import in test
* adopt new FA3 APIs from FA2.7.3+/hopper for CP and non-CP
* fix lint
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* relax tols for TransformerLayers
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix merge
* fix merge 2
* fix FA import comments
* relax tols for Ampere
* fix fa3 version and reduce messaging
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* update FA3 to its latest commit on main
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add default values to IP and assertion to graph.py
* add more comments in attention
* use custom_cache_manager instead of cache_manager
* [pre-commit.ci] auto fixes from pre-commit.com hooks

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 05f6a691
@@ -174,16 +174,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype"));
m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
m.def("fused_attn_fwd", &fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
py::call_guard<py::gil_scoped_release>());
m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
@@ -191,6 +183,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
py::call_guard<py::gil_scoped_release>());
// attention kernels
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
py::call_guard<py::gil_scoped_release>());
m.def("fused_attn_fwd", &fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
m.def("fused_attn_bwd", &fused_attn_bwd,
"Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache");
m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD");
m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tensor from BSHD to THD");
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
py::call_guard<py::gil_scoped_release>());
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
namespace transformer_engine {
namespace fused_attn {
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
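The kernel above copies each sequence's valid tokens from the packed THD layout into a fixed-stride BSHD buffer, one thread block per batch index. A minimal NumPy sketch of the same mapping (the function name and the zero-padding of unused slots are assumptions for illustration, not part of the kernel):

```python
import numpy as np

def thd_to_bshd(tensor, cu_seqlens, b, max_seq_len):
    """Repack a [total_tokens, h, d] THD tensor into [b, max_seq_len, h, d]."""
    h, d = tensor.shape[1], tensor.shape[2]
    out = np.zeros((b, max_seq_len, h, d), dtype=tensor.dtype)
    for batch_idx in range(b):
        # valid-token range for this sequence, from the cumulative lengths
        start, end = cu_seqlens[batch_idx], cu_seqlens[batch_idx + 1]
        out[batch_idx, : end - start] = tensor[start:end]
    return out
```

The inverse kernel below performs the same loop with source and destination swapped.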
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
template <typename scalar_t>
__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
int d_v, int b, int max_seq_len) {
// k_cache, v_cache: bshd
// batch_indices: [b]; cu_new_lens, cu_cached_lens: [b + 1]
int actual_b = b;
for (int i = 0; i < b - 1; i++) {
if (batch_indices[i + 1] < batch_indices[i]) {
actual_b = i + 1;
}
}
for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) {
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) {
int num_elts_k = h_kv * d_k;
int num_elts_v = h_kv * d_v;
int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k;
int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k;
int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v;
int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v;
for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) {
*(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i);
}
for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) {
*(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i);
}
}
}
}
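reindex_kv_cache_kernel first infers the live batch size from batch_indices: the active prefix holds the surviving batch slots, and a descent in the index values marks where padding begins. A small Python sketch of that inference step (helper name is hypothetical):

```python
def infer_actual_batch(batch_indices):
    """Mirror the kernel's scan: the last position where the index value
    drops marks the end of the real (compacted) batch."""
    actual_b = len(batch_indices)
    for i in range(len(batch_indices) - 1):
        if batch_indices[i + 1] < batch_indices[i]:
            actual_b = i + 1
    return actual_b
```

For example, [2, 5, 7, 0, 0] yields 3: three surviving sequences followed by zero padding.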
template <typename scalar_t>
__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
scalar_t *v_cache, int *page_table, int *cu_new_lens,
int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
// new_k, new_v: qkv_format; k_cache, v_cache: bshd
// cu_new_lens, cu_cached_lens: [b + 1]
// page_table: [b, max_pages_per_seq]
int page_size = max_seq_len / max_pages_per_seq;
if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int new_token_offset = batch_idx * max_ctx_len;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) =
*(new_k + (new_token_offset + i) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) =
*(new_v + (new_token_offset + i) * h_kv * d_v + j);
}
}
}
} else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j);
}
}
}
} else if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq;
int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx];
int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx];
for (int i = threadIdx.x; i < new_len; i += blockDim.x) {
int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size];
int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size;
for (int j = 0; j < h_kv * d_k; j++) {
*(k_cache + token_idx * h_kv * d_k + j) =
*(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j);
}
for (int j = 0; j < h_kv * d_v; j++) {
*(v_cache + token_idx * h_kv * d_v + j) =
*(new_v + (cu_new_lens[batch_idx] + i) * h_kv * d_v + j);
}
}
}
}
}
} // namespace fused_attn
} // namespace transformer_engine
#endif
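In copy_to_kv_cache_kernel, all three QKV formats share the same destination indexing: a new token's logical position in its sequence is cached_len - new_len + i, mapped to a physical slot either directly (non-paged, where the batch index acts as the page) or through the sequence's page list. A sketch of that mapping in Python (helper name is hypothetical):

```python
def cache_slot(i, cached_len, new_len, page_size, page_list, batch_idx, is_non_paged):
    """Return the flat token index in the bshd cache for the i-th new token."""
    pos = cached_len - new_len + i                 # logical position in the sequence
    if is_non_paged:
        page_idx = batch_idx                       # each sequence owns one "page"
    else:
        page_idx = page_list[pos // page_size]     # look up the physical page
    return page_idx * page_size + pos % page_size
```

For instance, with page_size=4 and page_list=[7, 2], a sequence holding 5 tokens after appending 2 new ones writes those new tokens into pages 7 and 2.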
@@ -2,52 +2,797 @@
#
# See LICENSE for license information.
"""Inference"""
import logging
from collections import OrderedDict, defaultdict
from typing import Optional, List
from einops import rearrange
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
__all__ = ["InferenceParams", "KVCacheManager", "NonPagedKVCacheManager", "PagedKVCacheManager"]
class KVCacheManager:
"""Base KV cache manager"""
def __init__(self):
"""Initialize cache manager"""
self.cache = {}
self.sequences = OrderedDict()
def reset(self):
"""Reset cache manager state"""
self.sequences = OrderedDict()
def allocate_memory(self, layer_number: int):
"""Allocate memory for the cache"""
self.cache[layer_number] = (None, None)
def pre_step(
self,
step_dict: OrderedDict, # pylint: disable=unused-argument
):
"""Update tracked sequences and prepare for step()"""
return self.sequences
def step(
self,
layer_number: int,
new_k: torch.Tensor, # pylint: disable=unused-argument
new_v: torch.Tensor, # pylint: disable=unused-argument
cu_new_seqlens: torch.Tensor, # pylint: disable=unused-argument
cu_cached_seqlens: torch.Tensor, # pylint: disable=unused-argument
qkv_format: str, # pylint: disable=unused-argument
):
"""Copy the new tokens to KV cache"""
return self.cache[layer_number]
class InferenceParams:
"""
KV caching for inference. The memory allocation of the caches and the copying of new tokens
to the cache take place at the following locations.::
class TransformerLayer:
class MultiHeadAttention:
if self.layer_number not in inference_params.cache_manager.cache:
inference_params.allocate_memory(self.layer_number)
class DotProductAttention:
if inference_params is not None:
k_cache, v_cache, new_qkv_format = inference_params.step(
new_k, new_v, qkv_format)
output = attention(new_q, k_cache, v_cache, new_qkv_format)
allocate_memory() can be called outside the model, independently. step() can take three formats,
qkv_format = {'bshd', 'sbhd', 'thd'}. It converts new_k and new_v to 'bshd' in both
NonPagedKVCacheManager and PagedKVCacheManager. The format of new_q may change depending on the
backend. If it is unchanged, we would have new_qkv_format = {'bshd', 'sbhd_2bshd', 'thd_2bshd'}.
A standard KV caching workflow for inference is as follows.::
model = [TransformerLayer() for _ in range(num_layers)]
# initialize InferenceParams, e.g. with PagedKVCacheManager
inference_params = InferenceParams(..., is_paged=True)
# inference loop
for i in range(num_iters):
# get info for iteration i, e.g. seq_ids = [0, 2, 3], step_lens = [10, 1, 1]
step_dict = OrderedDict(zip(seq_ids, step_lens))
# update inference_params' state
inference_params.pre_step(step_dict)
# run iteration
output = model(
...,
attn_mask_type="padding_causal",
cu_seqlens_q=cu_seqlens_new_q,
cu_seqlens_kv=cu_seqlens_new_kv,
inference_params=inference_params,
)
# get output tokens based on qkv_format
# 'bshd': output = output[:,step_dict.values()-1]
# 'sbhd': output = output[step_dict.values()-1,:]
# 'thd' : output = output[cu_seqlens_new_q[j+1]-1], j=0,...b-1
Parameters
----------
max_batch_size: int
Maximum batch size in inference
max_seqlen_kv: int
Maximum sequence length in inference
num_heads_kv: int
Number of attention heads in keys and values
head_dim_k: int
Head size for keys
dtype: torch.dtype
Data type of the KV cache
head_dim_v: int, default = None
Head size for values. If None, initialized as head_dim_k.
is_paged: bool, default = False
Whether the KV cache is paged (True) or non-paged (False)
total_num_pages: int, default = None
Total number of pages in the KV cache. Required for is_paged = True.
page_size: int, default = None
Page size of the KV cache. Required for is_paged = True.
max_ctx_len: int, default = None
Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv.
qkv_format: str, default = "bshd"
Format of the incoming query/key/value tensors in current iteration
custom_cache_manager: KVCacheManager, default = None
Custom cache manager, with KVCacheManager as the base class.
"""
def __init__(
self,
max_batch_size: int,
max_seqlen_kv: int,
num_heads_kv: int = 16,
head_dim_k: int = 64,
dtype: torch.dtype = torch.bfloat16,
head_dim_v: int = None,
is_paged: bool = False,
total_num_pages: int = None,
page_size: int = None,
max_ctx_len: int = None,
qkv_format: str = "bshd",
custom_cache_manager: KVCacheManager = None,
):
self.max_batch_size = max_batch_size
self.max_seqlen_kv = max_seqlen_kv
self.num_heads_kv = num_heads_kv
self.head_dim_k = head_dim_k
self.dtype = dtype
self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k
self.is_paged = is_paged
if not self.is_paged:
cache_manager = (
custom_cache_manager if custom_cache_manager is not None else NonPagedKVCacheManager
)
self.cache_manager = cache_manager(
max_batch_size=self.max_batch_size,
max_seqlen=self.max_seqlen_kv,
num_heads=self.num_heads_kv,
head_dim_k=self.head_dim_k,
dtype=self.dtype,
head_dim_v=self.head_dim_v,
)
else:
assert page_size is not None, "Paged KV cache requires page_size is not None."
self.page_size = page_size
assert (
max_seqlen_kv % page_size == 0
), "Paged KV cache requires max_seqlen_kv % page_size = 0."
max_pages_per_seq = max_seqlen_kv // page_size
assert (
total_num_pages == self.max_batch_size * max_pages_per_seq
), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq."
self.total_num_pages = total_num_pages
cache_manager = (
custom_cache_manager if custom_cache_manager is not None else PagedKVCacheManager
)
self.cache_manager = cache_manager(
total_num_pages=self.total_num_pages,
page_size=self.page_size,
num_heads=self.num_heads_kv,
head_dim_k=self.head_dim_k,
dtype=self.dtype,
max_batch_size=self.max_batch_size,
max_seqlen=self.max_seqlen_kv,
head_dim_v=self.head_dim_v,
)
if qkv_format == "thd":
assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!"
self.max_ctx_len = max_ctx_len
self.cache_qkv_format = "bshd"
self.input_qkv_format = qkv_format
if self.input_qkv_format == self.cache_qkv_format:
self.output_qkv_format = self.cache_qkv_format
else:
self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format
self.sequences_pre_step = OrderedDict()
self.sequences = OrderedDict()
self.batch_size = 0
self.cu_seqlens_q = torch.zeros(
self.max_batch_size + 1,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
self.cu_seqlens_kv = torch.zeros(
self.max_batch_size + 1,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
def reset(self):
"""Reset InferenceParams state"""
self.sequences = OrderedDict()
self.cache_manager.reset()
def __repr__(self) -> str:
if self.is_paged:
return (
f"dtype={self.dtype}, "
f"is_paged={self.is_paged}, "
f"total_pages={self.total_num_pages}, "
f"page_size={self.page_size}, "
f"num_heads={self.num_heads_kv}, "
f"head_dim_k={self.head_dim_k}, "
f"head_dim_v={self.head_dim_v}"
)
return (
f"dtype={self.dtype}, "
f"is_paged={self.is_paged}, "
f"max_batch_size={self.max_batch_size}, "
f"max_seqlen={self.max_seqlen_kv}, "
f"num_heads={self.num_heads_kv}, "
f"head_dim_k={self.head_dim_k}, "
f"head_dim_v={self.head_dim_v}"
)
def allocate_memory(self, layer_number: int):
"""
Allocate memory for the cache. For layer layer_number,
- NonPagedKVCacheManager:
- K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k]
- V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v]
- PagedKVCacheManager:
- K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k]
- V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v]
"""
self.cache_manager.allocate_memory(layer_number)
def pre_step(
self,
step_dict: OrderedDict,
):
"""Update tracked sequences and prepare for step()"""
self.batch_size = len(step_dict)
self.sequences = self.cache_manager.pre_step(step_dict)
# track the pre-step seqlens for the next layer in the model
self.sequences_pre_step = OrderedDict()
for k, v in self.sequences.items():
self.sequences_pre_step[k] = v - step_dict[k]
seqlens_q = list(step_dict.values())
cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)]
cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size)
self.cu_seqlens_q.copy_(torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu"))
seqlens_kv = list(self.sequences.values())
cu_seqlens_kv = [0] + [sum(seqlens_kv[:i]) for i in range(1, self.batch_size + 1)]
cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * (
self.max_batch_size - self.batch_size
)
self.cu_seqlens_kv.copy_(torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu"))
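pre_step above fills preallocated cumulative-length buffers of fixed size: the cumulative sums of the active sequences are padded out to max_batch_size + 1 entries by repeating the final total, so inactive batch slots contribute zero-length slices and downstream kernels (and, presumably, CUDA-graph capture) always see the same shape. A standalone sketch of that computation (function name is hypothetical):

```python
def build_cu_seqlens(seqlens, max_batch_size):
    """Cumulative sequence lengths, padded to a fixed max_batch_size + 1 entries."""
    cu = [0]
    for s in seqlens:
        cu.append(cu[-1] + s)
    # repeat the last value so inactive batch slots become zero-length ranges
    cu += [cu[-1]] * (max_batch_size - len(seqlens))
    return cu
```

With the docstring's example step_lens = [10, 1, 1] and max_batch_size = 5, this gives [0, 10, 11, 12, 12, 12].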
def get_seqlens_pre_step(self):
"""Get cached sequence lengths before this step"""
return torch.Tensor(list(self.sequences_pre_step.values())).to(
dtype=torch.int32, device="cpu"
)
def convert_paged_to_nonpaged(self, layer_number: int):
"""
Convert k_cache and v_cache from paged to non-paged format.
Parameters
----------
layer_number: int
Layer number of attention in the model
Returns
-------
k_cache: torch.Tensor
Non-paged key cache tensor
v_cache: torch.Tensor
Non-paged value cache tensor
"""
k_cache, v_cache = self.cache_manager.cache[layer_number]
page_table = self.cache_manager.page_table
batch_size = page_table.shape[0]
new_k_cache = rearrange(
k_cache[page_table.flatten()],
"(b npages) page_size ... -> b (npages page_size) ...",
b=batch_size,
)
new_v_cache = rearrange(
v_cache[page_table.flatten()],
"(b npages) page_size ... -> b (npages page_size) ...",
b=batch_size,
)
new_k_cache = new_k_cache[: self.batch_size].contiguous()
new_v_cache = new_v_cache[: self.batch_size].contiguous()
return new_k_cache, new_v_cache
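# The einops rearrange above is, per layer, just a gather of each sequence's
# pages followed by flattening the page axis into the token axis. The indexing
# semantics in plain Python, with a hypothetical page_size of 2 and 4 pages
# (head dimensions omitted for readability):

```python
page_size = 2
# Page-major cache: page p holds tokens [2p, 2p + 1].
k_cache = [[2 * p, 2 * p + 1] for p in range(4)]
# Sequence 0 owns pages [0, 2]; sequence 1 owns pages [1, 3].
page_table = [[0, 2], [1, 3]]

# Gather each row's pages, then flatten pages into one token axis per sequence,
# mirroring "(b npages) page_size ... -> b (npages page_size) ...".
new_k_cache = [[tok for page in row for tok in k_cache[page]] for row in page_table]
print(new_k_cache)  # [[0, 1, 4, 5], [2, 3, 6, 7]]
```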
def step(
self,
layer_number: int,
new_k: torch.Tensor,
new_v: torch.Tensor,
qkv_format: str,
):
"""
Copy new KV tokens to the cache.
Parameters
----------
layer_number: int
Layer number of attention in the model
new_k: torch.Tensor
New key tokens for layer_number in current inference iteration
new_v: torch.Tensor
New value tokens for layer_number in current inference iteration
qkv_format: str
Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
Returns
-------
k_cache: torch.Tensor
Full key tensor containing both previous and current key tokens
v_cache: torch.Tensor
Full value tensor containing both previous and current value tokens
cu_seqlens_q: torch.Tensor
Updated cumulative sequence lengths for query, [batch_size + 1]
cu_seqlens_kv: torch.Tensor
Updated cumulative sequence lengths for key and value, [batch_size + 1]
max_seqlen_q: int
Updated maximum sequence length for query
max_seqlen_kv: int
Updated maximum sequence length for key and value
qkv_format: str
Updated qkv_format, e.g. 'thd' format becomes 'thd_2bshd' after step()
"""
self.input_qkv_format = qkv_format
if self.input_qkv_format == self.cache_qkv_format:
self.output_qkv_format = self.cache_qkv_format
else:
self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format
k_cache, v_cache = self.cache_manager.step(
layer_number,
new_k,
new_v,
self.cu_seqlens_q,
self.cu_seqlens_kv,
qkv_format,
)
return (
k_cache,
v_cache,
self.cu_seqlens_q,
self.cu_seqlens_kv,
self.max_seqlen_q,
self.max_seqlen_kv,
self.output_qkv_format,
)
class NonPagedKVCacheManager(KVCacheManager):
"""Non-paged KV cache manager"""
def __init__(
self,
max_batch_size: int,
max_seqlen: int,
num_heads: int,
head_dim_k: int,
dtype: torch.dtype,
head_dim_v: Optional[int] = None,
):
"""Initialize cache manager"""
super().__init__()
self.max_batch_size = max_batch_size
self.max_seqlen = max_seqlen
self.num_heads = num_heads
self.head_dim_k = head_dim_k
self.dtype = dtype
self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k
# track sequences in the cache, {seq_id: seq_len}
self.sequences = OrderedDict()
# cache tensors, cache[layer_number] = (k_cache, v_cache)
self.cache = {}
# track sequence indices in the batch in order to re-index k_cache and v_cache
self.batch_indices = torch.zeros(
self.max_batch_size,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
# after re-indexing, batch indices are always [0, ..., b-1]
self.batch_indices_post_step = torch.arange(
self.max_batch_size,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
def allocate_memory(self, layer_number):
"""Allocate memory for the cache"""
k_cache = torch.zeros(
self.max_batch_size,
self.max_seqlen,
self.num_heads,
self.head_dim_k,
dtype=self.dtype,
device=torch.cuda.current_device(),
)
v_cache = torch.zeros(
self.max_batch_size,
self.max_seqlen,
self.num_heads,
self.head_dim_v,
dtype=self.dtype,
device=torch.cuda.current_device(),
)
self.cache[layer_number] = (k_cache, v_cache)
def pre_step(
self,
step_dict: OrderedDict,
):
"""Update tracked sequences and prepare for step()"""
# Track unfinished sequences' indices in the batch, e.g.
# at t-1, seq_ids = [0, 1, 2, 3]; at t, seq_ids = [0, 2, 3] since seq_id 1 is finished
# step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that
# they are contiguous and match the indexing in q
prev_batch_size = len(self.sequences)
unfinished_seqs = self.sequences.keys() & step_dict.keys()
finished_seqs = self.sequences.keys() - unfinished_seqs
unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs]
finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs]
self.batch_indices.copy_(
torch.Tensor(
(
unfinished_indices
+ finished_indices
+ list(range(prev_batch_size, self.max_batch_size))
)
).to(dtype=torch.int32, device="cpu")
)
# Advance unfinished sequences
for i in unfinished_seqs:
self.sequences[i] += 1
# Remove finished sequences
for i in finished_seqs:
self.sequences.pop(i)
# Add new sequences
new_seqs = step_dict.keys() - self.sequences.keys()
for i in new_seqs:
self.sequences[i] = step_dict[i]
return self.sequences
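# The re-indexing described in the comment above can be checked in isolation.
# A minimal sketch with hypothetical seq_ids and lengths, mirroring how
# batch_indices is constructed (no torch needed):

```python
from collections import OrderedDict

max_batch_size = 6
# Sequences tracked at t-1 (seq_id -> cached length); seq_id 1 finishes at step t.
sequences = OrderedDict([(0, 5), (1, 7), (2, 4), (3, 9)])
step_dict = {0: 1, 2: 1, 3: 1}  # only unfinished sequences submit new tokens

prev_batch_size = len(sequences)
unfinished_seqs = sequences.keys() & step_dict.keys()
finished_seqs = sequences.keys() - unfinished_seqs
unfinished_indices = [i for i, s in enumerate(sequences) if s in unfinished_seqs]
finished_indices = [i for i, s in enumerate(sequences) if s in finished_seqs]
# Unfinished rows first, finished rows rotated to the back, then unused slots.
batch_indices = (
    unfinished_indices + finished_indices + list(range(prev_batch_size, max_batch_size))
)
print(batch_indices)  # [0, 2, 3, 1, 4, 5]
```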
def step(
self,
layer_number,
new_k: torch.Tensor,
new_v: torch.Tensor,
cu_new_seqlens,
cu_cached_seqlens,
qkv_format: str,
):
"""
Copy the new tokens to the non-paged KV cache.
Parameters
----------
layer_number: int
Layer number of attention in the model
new_k: torch.Tensor
New key tokens for layer_number in current inference iteration
new_v: torch.Tensor
New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor
Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1]
qkv_format: str
Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
Returns
-------
k_cache: torch.Tensor
Full key tensor containing both previous and current key tokens
v_cache: torch.Tensor
Full value tensor containing both previous and current value tokens
"""
k_cache, v_cache = self.cache[layer_number]
batch_size = self.max_batch_size
ctx_len = 1
if qkv_format == "bshd":
batch_size = new_k.shape[0]
ctx_len = new_k.shape[1]
if qkv_format == "sbhd":
batch_size = new_k.shape[1]
ctx_len = new_k.shape[0]
tex.copy_to_kv_cache(
new_k,
new_v,
k_cache,
v_cache,
self.batch_indices,
cu_new_seqlens,
cu_cached_seqlens,
QKVFormat[qkv_format],
batch_size,
ctx_len,
self.max_seqlen,
1,
True,
)
k_cache = k_cache[:batch_size]
v_cache = v_cache[:batch_size]
return k_cache, v_cache
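# step() above derives batch_size and the per-step context length from the
# layout of the new-token tensor; 'thd' keeps the max_batch_size / ctx_len = 1
# defaults, since its token counts come from cu_seqlens instead. A sketch of
# that dispatch with bare shape tuples (batch_and_ctx is an illustrative
# helper, not part of the API):

```python
def batch_and_ctx(shape, qkv_format, max_batch_size):
    # Mirrors the dispatch in step(): bshd is [b, s, h, d], sbhd is [s, b, h, d],
    # and thd has no per-sequence length dimension, so defaults are used.
    if qkv_format == "bshd":
        return shape[0], shape[1]
    if qkv_format == "sbhd":
        return shape[1], shape[0]
    return max_batch_size, 1  # 'thd'

print(batch_and_ctx((2, 8, 4, 64), "bshd", 4))  # (2, 8)
print(batch_and_ctx((8, 2, 4, 64), "sbhd", 4))  # (2, 8)
print(batch_and_ctx((16, 4, 64), "thd", 4))     # (4, 1)
```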
class Page:
"""A single page"""
def __init__(self, page_id: int):
"""Initialize a page"""
self.page_id = page_id
self.allocated = False
def allocate_page(self):
"""Allocate a page"""
self.allocated = True
def deallocate_page(self):
"""Deallocate a page"""
self.allocated = False
class PagedKVCacheManager(KVCacheManager):
"""Paged KV cache manager"""
def __init__(
self,
total_num_pages: int,
page_size: int,
num_heads: int,
head_dim_k: int,
dtype: torch.dtype,
max_batch_size: int,
max_seqlen: int,
head_dim_v: Optional[int] = None,
):
"""Initialize cache manager"""
super().__init__()
self.total_num_pages = total_num_pages
self.page_size = page_size
self.num_heads = num_heads
self.head_dim_k = head_dim_k
self.dtype = dtype
self.max_batch_size = max_batch_size
self.max_seqlen = max_seqlen
self.max_pages_per_seq = max_seqlen // self.page_size
self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k
# track sequences in the cache, {seq_id: seq_len}
self.sequences = OrderedDict()
# cache tensors, cache[layer_number] = (k_cache, v_cache)
self.cache = {}
# available pages, [Page(),...]
self.free_pages = []
for i in range(self.total_num_pages):
self.free_pages.append(Page(i))
# allocated pages, {seq_id: [page_id,...]}
self.allocated_pages = defaultdict(list)
# page table, [batch_size, max_pages_per_seq]
self.page_table = torch.zeros(
self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda"
)
def reset(self):
"""Reset cache manager state"""
self.sequences = OrderedDict()
self.free_pages = []
for i in range(self.total_num_pages):
self.free_pages.append(Page(i))
self.allocated_pages = defaultdict(list)
self.page_table.fill_(0)
def allocate_memory(self, layer_number):
"""Allocate memory for the cache"""
k_cache = torch.zeros(
self.total_num_pages,
self.page_size,
self.num_heads,
self.head_dim_k,
dtype=self.dtype,
device=torch.cuda.current_device(),
)
v_cache = torch.zeros(
self.total_num_pages,
self.page_size,
self.num_heads,
self.head_dim_v,
dtype=self.dtype,
device=torch.cuda.current_device(),
)
self.cache[layer_number] = (k_cache, v_cache)
def print_cache(self):
"""Print KV cache status"""
used_pages = [self.get_page_count(seq) for seq in self.sequences]
logger = logging.getLogger("PagedKVCacheManager")
logger.debug("Cache status:")
logger.debug(
" total pages: %s (used %s, free %s)",
self.total_num_pages,
sum(used_pages),
len(self.free_pages),
)
logger.debug(" total sequences: %s", self.get_sequence_count())
for i, seq in enumerate(self.sequences):
logger.debug(
" >> batch index %s: seq_id %s, num_tokens %s, num_pages %s, page_list %s",
i,
seq,
self.get_sequence_lengths()[i],
self.get_page_count(seq),
self.get_page_list(seq),
)
def get_sequence_count(self):
"""Get the total number of sequences in the KV cache"""
return len(self.sequences)
def get_sequence_lengths(self):
"""Get the list of sequence lengths in the KV cache"""
return list(self.sequences.values())
def has_free_page(self) -> bool:
"""Whether the page pool has any free pages left"""
return len(self.free_pages) > 0
def get_page_count(self, seq: int):
"""Get the number of pages allocated to a sequence"""
return len(self.allocated_pages[seq])
def get_page_list(self, seq: int):
"""Get the list of pages allocated to a sequence"""
return [x.page_id for x in self.allocated_pages[seq]]
def get_page_table(self, sequences: List[int]):
"""Get the page table, in shape [batch_size, max_pages_per_seq]"""
page_table = torch.Tensor(
[
self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq))
for seq in sequences
]
).to(dtype=torch.int32, device="cpu")
self.page_table[: self.get_sequence_count()].copy_(page_table)
return self.page_table
def allocate_page(self, seq: int):
"""Allocate a new page to a sequence"""
if not self.has_free_page():
raise RuntimeError("KV cache is full!")
page = self.free_pages.pop(0)
page.allocate_page()
self.allocated_pages[seq].append(page)
def allocate_sequence(self, seq: int, context_len: int):
"""Add a new sequence to the cache"""
num_pages = context_len // self.page_size
if context_len % self.page_size > 0:
num_pages = num_pages + 1
for _ in range(num_pages):
self.allocate_page(seq)
def deallocate_sequence(self, seq: int):
"""Deallocate all the pages for a sequence"""
for page in self.allocated_pages[seq]:
page.deallocate_page()
self.free_pages.append(page)
self.allocated_pages.pop(seq)
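# allocate_sequence above reserves ceil(context_len / page_size) pages, and
# pre_step later adds one page each time a sequence's length crosses a page
# boundary. The ceil-division arithmetic in isolation (pages_needed is an
# illustrative helper, with a hypothetical page_size):

```python
page_size = 16

def pages_needed(context_len: int) -> int:
    # Same ceil division as allocate_sequence: any remainder costs one more page.
    num_pages = context_len // page_size
    if context_len % page_size > 0:
        num_pages += 1
    return num_pages

print([pages_needed(n) for n in (1, 16, 17, 32)])  # [1, 1, 2, 2]
```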
def pre_step(
self,
step_dict: OrderedDict,
):
"""Update tracked sequences and prepare for step()"""
# Remove finished sequences and advance unfinished sequences
unfinished_seqs = self.sequences.keys() & step_dict.keys()
finished_seqs = self.sequences.keys() - unfinished_seqs
for seq in finished_seqs:
self.sequences.pop(seq)
self.deallocate_sequence(seq)
for seq in unfinished_seqs:
if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen:
self.allocate_page(seq)
self.sequences[seq] += 1
# Add new sequences
new_seqs = step_dict.keys() - self.sequences.keys()
for seq in new_seqs:
self.sequences[seq] = step_dict[seq]
self.allocate_sequence(seq, step_dict[seq])
# Get page table
self.page_table = self.get_page_table(list(self.sequences.keys()))
return self.sequences
def step(
self,
layer_number: int,
new_k: torch.Tensor,
new_v: torch.Tensor,
cu_new_seqlens,
cu_cached_seqlens,
qkv_format: str,
):
"""
Copy the new tokens to the paged KV cache.
Parameters
----------
layer_number: int
Layer number of attention in the model
new_k: torch.Tensor
New key tokens for layer_number in current inference iteration
new_v: torch.Tensor
New value tokens for layer_number in current inference iteration
cu_new_seqlens: torch.Tensor
Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1]
cu_cached_seqlens: torch.Tensor
Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1]
qkv_format: str
Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'}
Returns
-------
k_cache: torch.Tensor
Full key tensor containing both previous and current key tokens
v_cache: torch.Tensor
Full value tensor containing both previous and current value tokens
"""
k_cache, v_cache = self.cache[layer_number]
batch_size = self.max_batch_size
ctx_len = 1
if qkv_format == "bshd":
batch_size = new_k.shape[0]
ctx_len = new_k.shape[1]
if qkv_format == "sbhd":
batch_size = new_k.shape[1]
ctx_len = new_k.shape[0]
tex.copy_to_kv_cache(
new_k,
new_v,
k_cache,
v_cache,
self.page_table,
cu_new_seqlens,
cu_cached_seqlens,
QKVFormat[qkv_format],
batch_size,
ctx_len,
self.max_seqlen,
self.max_pages_per_seq,
False,
)
return k_cache, v_cache
......@@ -34,6 +34,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
META_O_CP,
META_DQKV_CP,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.constants import TE_DType
......@@ -91,7 +92,6 @@ class FlashAttentionUtils:
Manage Flash Attention versioning information
"""
# Detect flash-attn v2 in the environment
is_installed = False
version = PkgVersion("0")
version_required = PkgVersion("2.1.1")
......@@ -102,21 +102,25 @@ class FlashAttentionUtils:
v2_3_plus = False
v2_4_plus = False
v2_4_1_plus = False
v2_5_plus = False
v2_5_7_plus = False
v2_6_0_plus = False
v2_7_0_plus = False
warning_printed = False
v3_is_installed = False
fa3_version = PkgVersion("0")
v3_0_0_beta = False
use_v3 = False
# TODO(cyang): update FA to 2.7.3 when its FA3 compilation issue is resolved
# https://github.com/Dao-AILab/flash-attention/issues/1452
# FA3 from FA 2.7.3+/hopper has different APIs than FA3 from 2.7.2/hopper
# Please follow these instructions to install FA3
v3_installation_steps = """\
(1) git clone https://github.com/Dao-AILab/flash-attention.git
(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
(4) mkdir -p $python_path/flash_attn_3
(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py"""
v3_warning_printed = False
@staticmethod
def set_flash_attention_version():
......@@ -129,13 +133,11 @@ class FlashAttentionUtils:
FlashAttentionUtils.v2_3_plus = FlashAttentionUtils.version >= PkgVersion("2.3")
FlashAttentionUtils.v2_4_plus = FlashAttentionUtils.version >= PkgVersion("2.4")
FlashAttentionUtils.v2_4_1_plus = FlashAttentionUtils.version >= PkgVersion("2.4.1")
FlashAttentionUtils.v2_5_plus = FlashAttentionUtils.version >= PkgVersion("2.5.0")
FlashAttentionUtils.v2_5_7_plus = FlashAttentionUtils.version >= PkgVersion("2.5.7")
FlashAttentionUtils.v2_6_0_plus = FlashAttentionUtils.version >= PkgVersion("2.6.0")
FlashAttentionUtils.v2_7_0_plus = FlashAttentionUtils.version >= PkgVersion("2.7.0")
# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
@staticmethod
def set_flash_attention_3_params():
"""
......@@ -145,7 +147,6 @@ class FlashAttentionUtils:
FlashAttentionUtils.v3_0_0_beta = (
PkgVersion("3.0.0b") < FlashAttentionUtils.fa3_version < PkgVersion("3.0.0")
)
FlashAttentionUtils.use_v3 = True
@dataclass(eq=True)
......@@ -203,6 +204,8 @@ class AttentionParams:
Whether `DotProductAttention` is in an `fp8_autocast` region.
fp8_meta: Optional[Dict[str Any]], default = `None`
The FP8 metadata tensor of `DotProductAttention`.
inference_params: Optional[InferenceParams], default = `None`
Inference-related parameters. See InferenceParams for details.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
......@@ -228,6 +231,7 @@ class AttentionParams:
is_training: bool = True
fp8: bool = False
fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None
def __eq__(self, other):
"""
......@@ -298,6 +302,7 @@ def get_attention_backend(
is_training = attention_params.is_training
fp8 = attention_params.fp8
fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params
# Run config
logger = logging.getLogger("DotProductAttention")
......@@ -334,13 +339,19 @@ def get_attention_backend(
# regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is
# necessary for performance/functionality, a warning will be issued to prompt users to
# install an appropriate FA version.
qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params)
# Filter: Environment variables
use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1"))
use_flash_attention_2 = use_flash_attention
use_flash_attention_3 = use_flash_attention
flash_attention_backend = None
use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1"))
use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1"))
if not use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0")
if not use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 due to NVTE_FLASH_ATTN=0")
if not use_fused_attention:
logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0")
if not use_unfused_attention:
......@@ -348,60 +359,124 @@ def get_attention_backend(
# Filter: Compute capability
if device_compute_capability < (8, 0):
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 for compute capability < sm80")
use_flash_attention_2 = False
if use_fused_attention:
logger.debug("Disabling FusedAttention for compute capability < sm80")
use_fused_attention = False
if device_compute_capability != (9, 0):
if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for compute capability != sm90")
use_flash_attention_3 = False
# Filter: Data type
if qkv_dtype not in [torch.bfloat16, torch.float16]:
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention 2 for unsupported qkv_dtype = %s. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16}. ",
qkv_dtype,
)
use_flash_attention_2 = False
if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in [
torch.Tensor,
Float8Tensor,
]:
if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
logger.debug(
"Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, "
"qkv_type = {torch.Tensor, Float8Tensor}. ",
qkv_dtype,
qkv_type,
)
use_flash_attention_3 = False
if use_fused_attention:
logger.debug(
"Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. "
"Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, "
"qkv_type = {torch.Tensor, Float8Tensor}. ",
qkv_dtype,
qkv_type,
)
use_fused_attention = False
# Filter: Execution type
if fp8 and fp8_meta["recipe"].fp8_dpa:
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 for FP8 attention")
use_flash_attention_2 = False
if use_flash_attention_3 and is_training:
if FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for FP8 training")
use_flash_attention_3 = False
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
use_unfused_attention = False
# Filter: KV cache
# backend | precision | KV cache | architecture | qkv_format | page_size
# ---------------------------------------------------------------------------------------
# Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1
# Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256
# Flash v3 | FP16/BF16 | non-paged/paged | sm90 | bshd,sbhd,thd | >= 1
# | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
if inference_params is not None:
if context_parallel:
logger.debug("Disabling all backends for KV caching with context parallelism")
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = False
if fp8 and fp8_meta["recipe"].fp8_dpa:
if fp8_meta["recipe"].fp8_mha:
logger.debug("Disabling all backends for KV caching with FP8 MHA")
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = False
if use_flash_attention_3 and q_format != "thd":
if FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for FP8 KV caching and non-THD")
use_flash_attention_3 = False
if use_fused_attention:
logger.debug("Disabling FusedAttention for FP8 KV caching")
use_fused_attention = False
else:
if q_format == "thd" and pad_between_seqs:
logger.debug("Disabling all backends for pad_between_seqs = True and KV caching")
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = False
if inference_params.is_paged:
if use_flash_attention_2 and inference_params.page_size < 256:
if FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 for page size < 256")
use_flash_attention_2 = False
if use_flash_attention_2:
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.5")
elif not FlashAttentionUtils.v2_5_plus:
logger.debug(
"Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+"
)
use_flash_attention_2 = False
# Filter: Head dimension
if head_dim_qk != head_dim_v:
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
):
logger.debug("Disabling FlashAttention as it does not support MLA.")
use_flash_attention = False
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and qkv_layout_group != "hd_hd_hd":
logger.debug(
"Disabling FusedAttention as MLA is not supported with qkv_layout = %s",
qkv_layout,
)
use_fused_attention = False
if use_flash_attention_2 and (
head_dim_qk > 256
or head_dim_qk % 8 != 0
or (
......@@ -411,7 +486,7 @@ def get_attention_backend(
):
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. "
"Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. "
"Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
"head_dim_qk <= 256 (>192 requires sm80/90/100+). "
"Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
......@@ -419,23 +494,21 @@ def get_attention_backend(
head_dim_v,
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention_2 = False
if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128):
if FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for head_dim > 128")
use_flash_attention_3 = False
# Filter: QKV layout
if qkv_format == "thd":
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd")
use_unfused_attention = False
if pad_between_seqs:
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
):
logger.debug(
"Disabling FlashAttention for qkv_format = thd when there is "
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
......@@ -443,9 +516,9 @@ def get_attention_backend(
use_flash_attention = False
# Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention_3:
logger.debug("Disabling FlashAttention 3 for dropout")
use_flash_attention_3 = False
# Filter: Context parallelism
# qkv_format | attn_mask_type | attn_bias_type | supported backends
......@@ -464,29 +537,26 @@ def get_attention_backend(
"Disabling UnfusedDotProductAttention as it does not support context parallelism"
)
use_unfused_attention = False
if context_parallel and (use_flash_attention_2 or use_flash_attention_3):
if FlashAttentionUtils.is_installed or FlashAttentionUtils.v3_is_installed:
if fp8 and fp8_meta["recipe"].fp8_dpa:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with FP8"
)
use_flash_attention = False
if "bottom_right" in attn_mask_type:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" causal_bottom_right masking"
)
use_flash_attention = False
elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" causal masking for cross-attention"
)
use_flash_attention = False
elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]:
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with bias"
" type of %s",
......@@ -494,7 +564,6 @@ def get_attention_backend(
)
use_flash_attention = False
elif qkv_format == "thd" and core_attention_bias_type != "no_bias":
logger.debug(
"Disabling FlashAttention as it does not support context parallelism with"
" attention bias for THD format"
......@@ -552,61 +621,25 @@ def get_attention_backend(
# arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention
# | [b, h, sq, skv] |
if attn_mask_type == "arbitrary":
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
):
logger.debug("Disabling FlashAttention for arbitrary mask")
use_flash_attention = False
if use_fused_attention:
logger.debug("Disabling FusedAttention for arbitrary mask")
use_fused_attention = False
if (
(use_flash_attention_2 or use_flash_attention_3)
and attn_mask_type in ["causal", "padding_causal"]
and max_seqlen_q != max_seqlen_kv
):
logger.warning(
"Disabling FlashAttention as it only supports bottom-right-diagonal "
"causal mask since flash-attn 2.1 (our minimum supported version). See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
# Filter: Sliding window attention
# backend | window_size | diagonal alignment
......@@ -637,19 +670,14 @@ def get_attention_backend(
"with s_q > s_kv for cross-attention"
)
use_fused_attention = False
if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.3")
elif not FlashAttentionUtils.v2_3_plus:
logger.debug(
"Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
)
use_flash_attention_2 = False
# Filter: Attention bias
# backend | bias types | ALiBi diagonal alignment
......@@ -660,21 +688,25 @@ def get_attention_backend(
# | | bottom_right (converts to a 'post_scale_bias' bias)
# UnfusedDotProductAttention | no_bias, pre/post_scale_bias |
# | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias
if core_attention_bias_type == "alibi":
if use_flash_attention_3:
if FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for ALiBi")
use_flash_attention_3 = False
if use_flash_attention_2:
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.4")
elif not FlashAttentionUtils.v2_4_plus:
logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
use_flash_attention_2 = False
if (
core_attention_bias_type not in ["no_bias", "alibi"]
or core_attention_bias_shape is not None
):
if FlashAttentionUtils.is_installed:
if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or (
use_flash_attention_3 and FlashAttentionUtils.v3_is_installed
):
logger.debug("Disabling FlashAttention for pre/post_scale_bias")
use_flash_attention = False
@@ -779,16 +811,16 @@ def get_attention_backend(
# | otherwise: no
# sub-backend 2 | no
# UnfusedDotProductAttention | yes
if use_flash_attention and deterministic:
if use_flash_attention_2 and deterministic:
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.4.1")
elif not FlashAttentionUtils.v2_4_1_plus and not FlashAttentionUtils.use_v3:
elif not FlashAttentionUtils.v2_4_1_plus:
logger.warning(
"Disabling FlashAttention as version <2.4.1 does not support deterministic "
"execution. To use FlashAttention with deterministic behavior, "
"please install flash-attn >= 2.4.1."
)
use_flash_attention = False
use_flash_attention_2 = False
if use_fused_attention and deterministic:
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
logger.debug("Disabling FusedAttention for determinism reasons")
@@ -805,29 +837,58 @@ def get_attention_backend(
logger.debug("Disabling FusedAttention for determinism reasons")
use_fused_attention = False
# All available backends
available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
# use_flash_attention may have been set above
use_flash_attention_2 = use_flash_attention and use_flash_attention_2
use_flash_attention_3 = use_flash_attention and use_flash_attention_3
# `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`.
# When `FusedAttention` does not support the provided attention params, and `FlashAttention`
# does, we recommend users to install flash-attn if not installed already.
if not use_fused_attention and use_flash_attention and not FlashAttentionUtils.is_installed:
if not use_fused_attention and _NVTE_FLASH_ATTN:
if (
use_flash_attention_3
and not FlashAttentionUtils.v3_is_installed
and not FlashAttentionUtils.v3_warning_printed
and torch.cuda.current_device() == 0
):
logger.warning(
"flash-attn v3 may provide important feature support or performance improvement."
" Please install flash-attn v3 by \n%s",
FlashAttentionUtils.v3_installation_steps,
)
FlashAttentionUtils.v3_warning_printed = True
elif (
use_flash_attention_2
and not FlashAttentionUtils.is_installed
and not FlashAttentionUtils.warning_printed
and torch.cuda.current_device() == 0
):
logger.warning(
"flash-attn may provide important feature support or performance improvement."
" Please install flash-attn %s.",
" Please install flash-attn %s by pip3 install flash-attn==<version>.",
_get_supported_versions(
FlashAttentionUtils.version_required,
FlashAttentionUtils.max_version,
),
)
if use_flash_attention and not FlashAttentionUtils.is_installed:
use_flash_attention = False
available_backends[0] = False
FlashAttentionUtils.warning_printed = True
# All available backends
if use_flash_attention_2 and not FlashAttentionUtils.is_installed:
use_flash_attention_2 = False
if use_flash_attention_3 and not FlashAttentionUtils.v3_is_installed:
use_flash_attention_3 = False
use_flash_attention = use_flash_attention_2 or use_flash_attention_3
available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention]
if use_flash_attention_2:
flash_attention_backend = FlashAttentionUtils.version
if use_flash_attention_3:
flash_attention_backend = FlashAttentionUtils.fa3_version
logger.debug(
"Available backends = {FlashAttention=%s, FusedAttention=%s%s,"
"Available backends = {FlashAttention=%s%s, FusedAttention=%s%s,"
" UnfusedDotProductAttention=%s}",
bool(available_backends[0]),
(f" ({str(flash_attention_backend)})" if flash_attention_backend is not None else ""),
bool(available_backends[1]),
(
f" (sub-backend {int(fused_attention_backend)})"
@@ -838,28 +899,12 @@ def get_attention_backend(
)
# Select FusedAttention for performance
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
):
if device_compute_capability >= (9, 0):
if use_flash_attention and use_fused_attention and device_compute_capability >= (9, 0):
logger.debug(
"Disabling FlashAttention to give FusedAttention preference on Hopper+ "
"for performance reasons"
)
use_flash_attention = False
if (
use_flash_attention
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["FP8"]
and FlashAttentionUtils.use_v3
):
logger.debug(
"Disabling FlashAttention 3 to give FusedAttention preference for performance reasons "
"in FP8 execution"
)
use_flash_attention = False
# Selected backend
if use_flash_attention:
@@ -869,22 +914,16 @@ def get_attention_backend(
use_unfused_attention = False
selected_backend = "NoBackend"
if use_flash_attention:
selected_backend = "FlashAttention"
selected_backend = f"FlashAttention ({str(flash_attention_backend)})"
elif use_fused_attention:
selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})"
elif use_unfused_attention:
selected_backend = "UnfusedDotProductAttention"
logger.debug("Selected backend = %s", selected_backend)
"""global _attention_backends
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False"""
return (
use_flash_attention,
flash_attention_backend,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
@@ -892,6 +931,49 @@ def get_attention_backend(
)
@torch.no_grad()
def get_padding_mask(
batch_size: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_q: int,
max_seqlen_kv: int,
):
"""Convert cu_seqlens to attention_mask"""
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
for i in range(batch_size):
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor([False] * seqlens_q[i] + [True] * (max_seqlen_q - seqlens_q[i]))
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask_kv = torch.cat(
[
attention_mask_kv,
torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i]))
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = (
attention_mask_q.to(device="cuda"),
attention_mask_kv.to(device="cuda"),
)
return attention_mask
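The mask convention above (False for valid tokens, True for padded positions, one row per sequence) can be sketched without torch; `padding_mask_from_cu_seqlens` is a hypothetical helper for illustration only, not part of this PR:

```python
def padding_mask_from_cu_seqlens(cu_seqlens, max_seqlen):
    """Illustrative mirror of get_padding_mask: False marks valid tokens,
    True marks padded positions, one row per sequence in the batch."""
    # per-sequence lengths, i.e. cu_seqlens[1:] - cu_seqlens[:-1]
    seqlens = [b - a for a, b in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    return [[False] * s + [True] * (max_seqlen - s) for s in seqlens]

# cu_seqlens = [0, 2, 5]: two sequences of lengths 2 and 3, padded to 4
print(padding_mask_from_cu_seqlens([0, 2, 5], 4))
# [[False, False, True, True], [False, False, False, True]]
```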
@torch.no_grad()
def get_full_mask(
max_seqlen_q: int,
@@ -1400,11 +1482,46 @@ class UnpackTensor(torch.autograd.Function):
return None, None, _pack_tensor(indices, grad_output)
def get_qkv_format(
qkv_layout: str = "bshd_bshd_bshd",
inference_params: InferenceParams = None,
) -> str:
"""Get qkv format.
Parameters
----------
qkv_layout: str
Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details.
inference_params: InferenceParams, default = `None`
InferenceParams related to KV caching.
Returns
----------
qkv_format: str
Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}.
q_format: str
Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}.
kv_format: str
Format of the `k` and `v` tensors, {`bshd`, `sbhd`, `thd`}.
"""
split_layout = qkv_layout.replace("paged_kv_", "").split("_")
if inference_params is not None:
q_format = "".join([i for i in split_layout[0] if i.isalpha()])
kv_format = "".join([i for i in split_layout[1] if i.isalpha()])
qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format
else:
qkv_format = "".join([i for i in split_layout[0] if i.isalpha()])
q_format = qkv_format
kv_format = qkv_format
return qkv_format, q_format, kv_format
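The layout-string parsing above can be exercised with a small stand-alone sketch. It assumes, as the function does, that the layout carries separate `q` and `kv` chunks whenever inference params are present; the helper name is illustrative only:

```python
def qkv_format_from_layout(qkv_layout, has_inference_params=False):
    """Illustrative mirror of get_qkv_format's parsing. Assumes the layout
    has separate q and kv chunks (e.g. "thd_bshd_bshd") when
    has_inference_params is True."""
    def letters(chunk):
        # keep the dimension letters, dropping interleave digits like "3" in "sb3hd"
        return "".join(c for c in chunk if c.isalpha())

    chunks = qkv_layout.replace("paged_kv_", "").split("_")
    if has_inference_params:
        q_format, kv_format = letters(chunks[0]), letters(chunks[1])
        qkv_format = q_format if q_format == kv_format else q_format + "_2" + kv_format
    else:
        qkv_format = q_format = kv_format = letters(chunks[0])
    return qkv_format, q_format, kv_format

print(qkv_format_from_layout("sb3hd"))                         # ('sbhd', 'sbhd', 'sbhd')
print(qkv_format_from_layout("paged_kv_thd_bshd_bshd", True))  # ('thd_2bshd', 'thd', 'bshd')
```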
def get_qkv_layout(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
qkv_format: str = "sbhd",
inference_params: InferenceParams = None,
) -> str:
"""Get qkv layout.
@@ -1421,20 +1538,33 @@ def get_qkv_layout(
the sequence length dimension, `b` batch size, `h` the number of attention heads,
`d` head size, and `t` the total number of tokens in a batch, i.e.
`t = sum(s_i) for i = 0...b-1`.
inference_params: InferenceParams, default = `None`
InferenceParams related to KV caching.
Returns
----------
qkv_layout: str
Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five
memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk
of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means
`q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v`
are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and
`v = kv[:,:,:,1,:]`.
Memory layout of `q`, `k` and `v`. Each `qkv_layout` maps to a pair of `q_format` and
`kv_format` in {`bshd`, `sbhd`, `thd`}. The `paged_kv_` prefix is used to indicate that
paged KV caching is in play. A few examples of the layouts are as follows.
(1) `sb3hd` means `q`, `k`, `v` are created as one chunk of memory and that they are
interleaved in the `2`nd dimension. (2) `sbhd_sbh2d` means `q` and `kv` are created in
two chunks and that `q` itself is contiguous and `k`, `v` are interleaved with each other
in the `3`rd dimension, `k = kv[:,:,:,0,:]` and `v = kv[:,:,:,1,:]`. `q_format` and
`kv_format` in this case are still both `sbhd`. (3) `paged_kv_thd_bshd_bshd` means `q` is
created in `thd` and `k`, `v` are in `bshd`, which follows the cache format used in
paged KV caching.
Mapping:
`sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`}
`bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`}
`sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`, `paged_kv_sbhd_sbhd_sbhd`}
`bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`, `paged_kv_bshd_bshd_bshd`}
`thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`}
`sbhd_2bshd`: {`sbhd_bshd_bshd`, `paged_kv_sbhd_bshd_bshd`}
`bshd_2sbhd`: {`bshd_sbhd_sbhd`, `paged_kv_bshd_sbhd_sbhd`}
`thd_2bshd`: {`thd_bshd_bshd`, `paged_kv_thd_bshd_bshd`}
`thd_2sbhd`: {`thd_sbhd_sbhd`, `paged_kv_thd_sbhd_sbhd`}
q: torch.Tensor
Query tensor. It may be different from input `q` as we try to fit tensors to
a supported layout.
@@ -1444,10 +1574,21 @@ def get_qkv_layout(
v: torch.Tensor
Value tensor. It may be different from input `v` as we try to fit tensors to
a supported layout.
q_format: str
Format of the query tensor, {`bshd`, `sbhd`, `thd`}.
kv_format: str
Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}.
"""
check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v])
assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!"
if "_2" in qkv_format:
q_format, kv_format = qkv_format.split("_2")
is_same_q_kv_format = False
else:
q_format = qkv_format
kv_format = qkv_format
is_same_q_kv_format = True
def run_iteratively(q, k, v):
# check data pointers
@@ -1534,7 +1675,10 @@ def get_qkv_layout(
# three chunks of memory, q, k and v, which may be disjoint or consecutive, and
# when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
# check_ptrs_qk=True or check_ptrs_kv=True
if is_same_q_kv_format:
qkv_layout = "_".join([qkv_format] * 3)
else:
qkv_layout = q_format + "_" + kv_format + "_" + kv_format
else:
qkv_layout = "not_supported"
@@ -1548,7 +1692,10 @@ def get_qkv_layout(
if qkv_layout == "not_supported":
raise RuntimeError("The provided qkv memory layout is not supported!")
return qkv_layout, q, k, v
if inference_params is not None and inference_params.is_paged:
qkv_layout = "paged_kv_" + qkv_layout
return qkv_layout, q, k, v, q_format, kv_format
def check_set_window_size(
@@ -91,6 +91,14 @@ def _make_graphed_callables(
sample_args = (sample_args,)
sample_kwargs = (sample_kwargs,)
# Check training/inference
is_training = all(c.training for c in callables)
if not is_training and any(c.training for c in callables):
raise ValueError(
"make_graphed_callables requires all modules to be in training mode, or all in"
" inference mode."
)
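The all-or-nothing mode check above can be demonstrated without torch, using a minimal stand-in for a module's `training` flag (the class and function names here are illustrative, not part of the API):

```python
class FakeModule:
    """Minimal stand-in for torch.nn.Module's training flag (illustration only)."""
    def __init__(self, training):
        self.training = training

def check_uniform_mode(callables):
    # Mirrors the guard above: every module must be in the same mode
    is_training = all(c.training for c in callables)
    if not is_training and any(c.training for c in callables):
        raise ValueError("modules must be all-training or all-inference")
    return is_training

print(check_uniform_mode([FakeModule(True), FakeModule(True)]))    # True
print(check_uniform_mode([FakeModule(False), FakeModule(False)]))  # False
```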
# Check sizes of args
if _order is None:
assert len(sample_args) == len(callables)
@@ -255,6 +263,7 @@ def _make_graphed_callables(
outputs, _ = _tree_flatten(func(*args, **kwargs))
for hook in hooks:
hook.remove()
if is_training:
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
@@ -262,6 +271,8 @@ def _make_graphed_callables(
only_inputs=True,
allow_unused=allow_unused_input,
)
else:
grad_inputs = None
del outputs, grad_inputs
# The following code is added specifically for MCore's special requirements,
# aimed at preventing warmup from altering the control flow.
@@ -314,6 +325,7 @@ def _make_graphed_callables(
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
@@ -329,7 +341,7 @@ def _make_graphed_callables(
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad:
if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
@@ -366,6 +378,7 @@ def _make_graphed_callables(
static_grad_outputs = tuple(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
with torch.cuda.graph(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
@@ -381,7 +394,7 @@ def _make_graphed_callables(
static_grad_inputs = []
grad_idx = 0
for arg in static_input_surface:
if arg.requires_grad:
if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
@@ -422,7 +435,10 @@ def _make_graphed_callables(
# Copy values from new tensors into static tensors
for i in range(len_user_args):
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
if (
isinstance(static_input_surface[i], torch.Tensor)
and static_input_surface[i].data_ptr() != inputs[i].data_ptr()
):
static_input_surface[i].copy_(inputs[i])
# Replay forward graph