Unverified commit 4f33ece4 authored by Charlene Yang, committed by GitHub

Add KV cache for paged/non-paged attention (#1355)



* add paged attention; test_kv_cache_accuracy and test_paged_attn pass
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
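
For context, paged attention keeps the KV cache in fixed-size pages and uses a per-sequence page table to map token positions to physical storage, so cache memory grows in page-sized chunks instead of one maximal allocation per sequence. A minimal sketch of the idea, with hypothetical names rather than this PR's actual classes:

```python
import torch

# Minimal paged KV cache sketch; names are illustrative, not this PR's API.
class PagedKVCache:
    def __init__(self, num_pages, page_size, num_heads, head_dim, dtype=torch.bfloat16):
        self.page_size = page_size
        # Physical pools of pages shared by all sequences.
        self.k_pool = torch.zeros(num_pages, page_size, num_heads, head_dim, dtype=dtype)
        self.v_pool = torch.zeros_like(self.k_pool)
        self.free_pages = list(range(num_pages))
        self.page_tables = {}  # seq_id -> list of physical page indices
        self.seq_lens = {}     # seq_id -> number of cached tokens

    def append(self, seq_id, k_new, v_new):
        """Scatter new tokens ([t, num_heads, head_dim]) into the sequence's pages."""
        table = self.page_tables.setdefault(seq_id, [])
        pos = self.seq_lens.get(seq_id, 0)
        for i in range(k_new.size(0)):
            if pos % self.page_size == 0:  # current page is full; allocate another
                table.append(self.free_pages.pop())
            page, slot = table[pos // self.page_size], pos % self.page_size
            self.k_pool[page, slot] = k_new[i]
            self.v_pool[page, slot] = v_new[i]
            pos += 1
        self.seq_lens[seq_id] = pos
```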

* remove unnecessary change from last commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test_fused_attn pass
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove unnecessary import in test_numerics
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add license for test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add to L0 test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update license for test_paged_attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update kv_cache_manager license
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix build issue from previous merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: minor fix/preparation for inference/cuda graph
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, bshd/sbhd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, thd, no CG
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: non-paged, thd, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, using paged kernel
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: restructure kernels
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: paged, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: padding + BRCM
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: restructure IP, clean up
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix non-CG, fused
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix last commit
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: unfused, non-CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash-attn, non-CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash_attn_with_kvcache
Signed-off-by: Charlene Yang <charleney@nvidia.com>
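
flash_attn_with_kvcache is the flash-attn package's decode-time entry point: it appends the step's new K/V into the cache in place and attends against the cached tokens, optionally through a block table for paged layouts. A rough usage sketch with illustrative shapes:

```python
import torch
from flash_attn import flash_attn_with_kvcache

b, h, d, max_seq = 2, 16, 64, 4096
q = torch.randn(b, 1, h, d, dtype=torch.bfloat16, device="cuda")  # one new token per sequence
k_new = torch.randn(b, 1, h, d, dtype=torch.bfloat16, device="cuda")
v_new = torch.randn(b, 1, h, d, dtype=torch.bfloat16, device="cuda")
k_cache = torch.zeros(b, max_seq, h, d, dtype=torch.bfloat16, device="cuda")
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.tensor([17, 33], dtype=torch.int32, device="cuda")  # tokens already cached

# Writes k_new/v_new at offset cache_seqlens, then attends over the cache.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens, causal=True
)
```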

* commit two files missed by bcef6b34
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: thd_bshd_bshd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix last commit
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix 1c31b68d
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: add bshd_2sbhd, sbhd_2bshd
Signed-off-by: Charlene Yang <charleney@nvidia.com>
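
The qkv_format names here encode layouts: bshd is [batch, seq, head, dim], sbhd is [seq, batch, head, dim], and thd packs variable-length sequences along a single token axis indexed by cumulative sequence lengths. A minimal sketch of the thd-to-bshd padding conversion (hypothetical helper, assuming cu_seqlens offsets):

```python
import torch

def thd_to_bshd(x_thd, cu_seqlens, max_seqlen):
    """Unpack [total_tokens, h, d] (thd) into zero-padded [b, max_seqlen, h, d] (bshd)."""
    b = cu_seqlens.numel() - 1
    _, h, d = x_thd.shape
    out = x_thd.new_zeros(b, max_seqlen, h, d)
    for i in range(b):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        out[i, : end - start] = x_thd[start:end]
    return out

# bshd <-> sbhd is just a transpose of the first two axes:
# x_sbhd = x_bshd.transpose(0, 1).contiguous()
```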

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: some cleanup
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: all qkv_format combinations and merge CM files
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: some lint fixes
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: add docstring for IP
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix sequences_pre
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: minor fixes for multi-layer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: initial multi-layer test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: switch to flash_attn_varlen_func
Signed-off-by: Charlene Yang <charleney@nvidia.com>
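
flash_attn_varlen_func takes thd-packed q/k/v plus cumulative sequence lengths, so no padding is needed and q and kv may carry different per-sequence lengths, which is the usual shape of a cache step (short new queries against long cached keys). A rough call sketch:

```python
import torch
from flash_attn import flash_attn_varlen_func

h, d = 16, 64
# Two packed sequences: q lengths 3 and 5, kv lengths 10 and 12.
cu_seqlens_q = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 10, 22], dtype=torch.int32, device="cuda")
q = torch.randn(8, h, d, dtype=torch.bfloat16, device="cuda")
k = torch.randn(22, h, d, dtype=torch.bfloat16, device="cuda")
v = torch.randn(22, h, d, dtype=torch.bfloat16, device="cuda")

out = flash_attn_varlen_func(
    q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q=5, max_seqlen_k=12, causal=True
)
```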

* WIP: fix unfused for separate q/kv format
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix fused for separate q/kv formats
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash attn + TELayer + 2 layers
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: unfused + TELayer + 2 layers
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: all modules/backend
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: minor cleanup
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: FlashAttention on Hopper with 2.7.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: FlashAttention + v3 from 39e7179
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: FlashAttention + v3 + FP8 + WIP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add backend support table
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: separate use_flash_attention_2 and _3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: tweaks to paged attn script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: enable/disable certain cases for fused attn
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: small fixes for lint and cg
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: minor fixes for attn/infer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix CP
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: readd page info to FADescriptor_v1
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor tweak to test_numerics.py
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix 9.5/9.7 sq/skv + mask logic
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* clean up
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fix for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* more minor fixes for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test page_size=1 for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix t3hd/th3d strides
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix ckpt recompute and fa3 k_scale
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* raise dynamo recompile limit for test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove thunder test from L0
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix FA selection logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FA3 q_descale shape
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove page_table from IP.step() returns
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
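
IP is the InferenceParams object that owns the cache state; after this change the page table stays internal to it instead of being threaded through step()'s return values. A toy sketch of why the caller never needs it back (names hypothetical, not the TE API):

```python
# Toy sketch: the params object owns cache bookkeeping; attention reads it
# by reference, so step() only needs to advance lengths. Names hypothetical.
class ToyInferenceParams:
    def __init__(self, max_batch_size, max_seqlen):
        self.max_batch_size, self.max_seqlen = max_batch_size, max_seqlen
        self.seq_lens = {}    # seq_id -> cached length
        self.page_table = {}  # owned here, never handed back to the caller

    def step(self, new_tokens_per_seq):
        """Advance per-sequence cached lengths for one decode step."""
        for seq_id, n in new_tokens_per_seq.items():
            self.seq_lens[seq_id] = self.seq_lens.get(seq_id, 0) + n

params = ToyInferenceParams(max_batch_size=8, max_seqlen=4096)
params.step({0: 1, 1: 1})  # one new token each for sequences 0 and 1
assert params.seq_lens == {0: 1, 1: 1}
```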

* fix FP8 FlashAttn DPA fp8_dpa tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor tweaks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update FA3 note and L3 test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove redundant import in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* adopt new FA3 APIs from FA2.7.3+/hopper for CP and non-CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* relax tols for TransformerLayers
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix merge 2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FA import comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* relax tols for Ampere
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fa3 version and reduce messaging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update FA3 to its latest commit on main
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add default values to IP and assertion to graph.py
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add more comments in attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use custom_cache_manager instead of cache_manager
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 05f6a691
@@ -174,16 +174,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype"));
   m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
-  m.def("fused_attn_fwd", &fused_attn_fwd,
-        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
-  m.def("fused_attn_bwd", &fused_attn_bwd,
-        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
   m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
         py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
-  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
         py::call_guard<py::gil_scoped_release>());
   m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
@@ -191,6 +183,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::call_guard<py::gil_scoped_release>());
   m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
         py::call_guard<py::gil_scoped_release>());
+  // attention kernels
+  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("fused_attn_fwd", &fused_attn_fwd,
+        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
+  m.def("fused_attn_bwd", &fused_attn_bwd,
+        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
+  m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache");
+  m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD");
+  m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tensor from BSHD to THD");
+  // fused apply rope
   m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
         py::call_guard<py::gil_scoped_release>());
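
Of the new bindings, copy_to_kv_cache is the cache-write primitive; for the non-paged layout it amounts to scattering each sequence's new tokens into the cache at that sequence's current length. A Python-level sketch of the semantics (the binding dispatches to a CUDA kernel; this reference helper is hypothetical):

```python
import torch

def copy_to_kv_cache_ref(k_cache, v_cache, k_new, v_new, cache_seqlens):
    """Reference semantics for a non-paged cache write.

    k_cache/v_cache: [b, max_seq, h, d]; k_new/v_new: [b, s_new, h, d];
    cache_seqlens: [b] tokens already cached per sequence.
    """
    s_new = k_new.size(1)
    for i in range(k_cache.size(0)):
        start = int(cache_seqlens[i])
        k_cache[i, start : start + s_new] = k_new[i]
        v_cache[i, start : start + s_new] = v_new[i]
    return cache_seqlens + s_new
```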